Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/new saas billing #12591

Merged
merged 6 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
Expand Down Expand Up @@ -93,6 +98,7 @@ def get(self):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
Expand Down Expand Up @@ -207,6 +213,7 @@ def get(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
Expand Down Expand Up @@ -310,6 +317,7 @@ def patch(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)

Expand Down
10 changes: 10 additions & 0 deletions api/controllers/console/datasets/datasets_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
Expand Down Expand Up @@ -230,6 +231,7 @@ def get(self, dataset_id):
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id = str(dataset_id)

Expand Down Expand Up @@ -284,6 +286,7 @@ def post(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand All @@ -307,6 +310,7 @@ class DatasetInitApi(Resource):
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
Expand Down Expand Up @@ -679,6 +683,7 @@ class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
Expand Down Expand Up @@ -715,6 +720,7 @@ class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
Expand Down Expand Up @@ -783,6 +789,7 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand Down Expand Up @@ -878,6 +885,7 @@ class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -910,6 +918,7 @@ class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -939,6 +948,7 @@ class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""retry document."""

Expand Down
11 changes: 11 additions & 0 deletions api/controllers/console/datasets/datasets_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
Expand Down Expand Up @@ -106,6 +107,7 @@ def get(self, dataset_id, document_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand Down Expand Up @@ -192,6 +195,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -242,6 +246,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -302,6 +307,7 @@ def patch(self, dataset_id, document_id, segment_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -339,6 +345,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -405,6 +412,7 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -503,6 +511,7 @@ def get(self, dataset_id, document_id, segment_id):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -546,6 +555,7 @@ class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -590,6 +600,7 @@ def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down
7 changes: 6 additions & 1 deletion api/controllers/console/datasets/hit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required


class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

Expand Down
33 changes: 32 additions & 1 deletion api/controllers/console/wraps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
import os
import time
from functools import wraps

from flask import abort, request
from flask_login import current_user # type: ignore

from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_redis import redis_client
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
Expand Down Expand Up @@ -66,7 +68,9 @@ def decorated(*args, **kwargs):
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
Expand Down Expand Up @@ -111,6 +115,33 @@ def decorated(*args, **kwargs):
return interceptor


def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"

redis_client.zadd(key, {current_time: current_time})

redis_client.zremrangebyscore(key, 0, current_time - 60000)

request_count = redis_client.zcard(key)

if request_count > knowledge_rate_limit.limit:
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)

return decorated

return interceptor


def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
Expand Down
31 changes: 31 additions & 0 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime
from enum import Enum
Expand All @@ -11,6 +12,7 @@
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser
Expand Down Expand Up @@ -137,6 +139,35 @@ def decorated(*args, **kwargs):
return interceptor


def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)

if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{api_token.tenant_id}"

redis_client.zadd(key, {current_time: current_time})

redis_client.zremrangebyscore(key, 0, current_time - 60000)

request_count = redis_client.zcard(key)

if request_count > knowledge_rate_limit.limit:
raise Forbidden(
"Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)

return decorated

return interceptor


def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast

Expand All @@ -19,8 +20,10 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService

from .entities import KnowledgeRetrievalNodeData
from .exc import (
Expand Down Expand Up @@ -61,6 +64,23 @@ def _run(self) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)

# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
Expand Down
8 changes: 8 additions & 0 deletions api/services/billing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def get_info(cls, tenant_id: str):

return billing_info

@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

return knowledge_rate_limit.get("limit", 10)

@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
Expand Down
Loading
Loading