Skip to content

Commit

Permalink
Feat/new saas billing (#12591)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Jan 12, 2025
1 parent 989fb11 commit d8f57bf
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 3 deletions.
10 changes: 9 additions & 1 deletion api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
Expand Down Expand Up @@ -93,6 +98,7 @@ def get(self):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
Expand Down Expand Up @@ -207,6 +213,7 @@ def get(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
Expand Down Expand Up @@ -310,6 +317,7 @@ def patch(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)

Expand Down
10 changes: 10 additions & 0 deletions api/controllers/console/datasets/datasets_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
Expand Down Expand Up @@ -230,6 +231,7 @@ def get(self, dataset_id):
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id = str(dataset_id)

Expand Down Expand Up @@ -285,6 +287,7 @@ def post(self, dataset_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand All @@ -308,6 +311,7 @@ class DatasetInitApi(Resource):
@account_initialization_required
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
Expand Down Expand Up @@ -680,6 +684,7 @@ class DocumentProcessingApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
Expand Down Expand Up @@ -716,6 +721,7 @@ class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
Expand Down Expand Up @@ -784,6 +790,7 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand Down Expand Up @@ -879,6 +886,7 @@ class DocumentPauseApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -911,6 +919,7 @@ class DocumentRecoverApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -940,6 +949,7 @@ class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""retry document."""

Expand Down
11 changes: 11 additions & 0 deletions api/controllers/console/datasets/datasets_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
Expand Down Expand Up @@ -106,6 +107,7 @@ def get(self, dataset_id, document_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
Expand Down Expand Up @@ -192,6 +195,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -242,6 +246,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -302,6 +307,7 @@ def patch(self, dataset_id, document_id, segment_id):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -339,6 +345,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -405,6 +412,7 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -503,6 +511,7 @@ def get(self, dataset_id, document_id, segment_id):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -546,6 +555,7 @@ class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down Expand Up @@ -590,6 +600,7 @@ def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
Expand Down
7 changes: 6 additions & 1 deletion api/controllers/console/datasets/hit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required


class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

Expand Down
33 changes: 32 additions & 1 deletion api/controllers/console/wraps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
import os
import time
from functools import wraps

from flask import abort, request
from flask_login import current_user # type: ignore

from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_redis import redis_client
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
Expand Down Expand Up @@ -66,7 +68,9 @@ def decorated(*args, **kwargs):
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
Expand Down Expand Up @@ -111,6 +115,33 @@ def decorated(*args, **kwargs):
return interceptor


def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"

redis_client.zadd(key, {current_time: current_time})

redis_client.zremrangebyscore(key, 0, current_time - 60000)

request_count = redis_client.zcard(key)

if request_count > knowledge_rate_limit.limit:
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)

return decorated

return interceptor


def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
Expand Down
31 changes: 31 additions & 0 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from collections.abc import Callable
from datetime import UTC, datetime, timedelta
from enum import Enum
Expand All @@ -13,6 +14,7 @@
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser
Expand Down Expand Up @@ -139,6 +141,35 @@ def decorated(*args, **kwargs):
return interceptor


def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)

if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{api_token.tenant_id}"

redis_client.zadd(key, {current_time: current_time})

redis_client.zremrangebyscore(key, 0, current_time - 60000)

request_count = redis_client.zcard(key)

if request_count > knowledge_rate_limit.limit:
raise Forbidden(
"Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)

return decorated

return interceptor


def validate_dataset_token(view=None):
def decorator(view):
@wraps(view)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast

Expand All @@ -19,8 +20,10 @@
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService

from .entities import KnowledgeRetrievalNodeData
from .exc import (
Expand Down Expand Up @@ -61,6 +64,23 @@ def _run(self) -> NodeRunResult:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)

# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
Expand Down
8 changes: 8 additions & 0 deletions api/services/billing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def get_info(cls, tenant_id: str):
billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info

@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

return knowledge_rate_limit.get("limit", 10)

@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
Expand Down
Loading

0 comments on commit d8f57bf

Please sign in to comment.