Skip to content

Commit

Permalink
Refactor(sub-endpoints)!: refactor sub endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
erfjab committed Aug 18, 2024
1 parent 36532a9 commit 215bf32
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 132 deletions.
34 changes: 23 additions & 11 deletions app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Optional, Union
from app.models.admin import AdminInDB, AdminValidationResult
from app.models.user import UserResponse
from app.db import Session, crud, get_db
from config import SUDOERS
from fastapi import Depends, HTTPException
from datetime import datetime

from app.utils.jwt import get_subscription_payload

def validate_admin(db: Session, username: str, password: str) -> Optional[AdminValidationResult]:
"""
Validate admin credentials with environment variables or database.
"""
"""Validate admin credentials with environment variables or database."""
if SUDOERS.get(username) == password:
return AdminValidationResult(username, True)

Expand All @@ -21,9 +20,7 @@ def validate_admin(db: Session, username: str, password: str) -> Optional[AdminV


def get_admin_by_username(username: str, db: Session = Depends(get_db)):
"""
Fetch an admin by username from the database.
"""
"""Fetch an admin by username from the database."""
dbadmin = crud.get_admin(db, username)
if not dbadmin:
raise HTTPException(status_code=404, detail="Admin not found")
Expand All @@ -37,9 +34,7 @@ def get_dbnode(node_id: int, db: Session = Depends(get_db)):
return dbnode

def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> bool:
"""
Validate if start and end dates are correct and if end is after start.
"""
"""Validate if start and end dates are correct and if end is after start."""
try:
if start:
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start)
Expand All @@ -56,4 +51,21 @@ def get_user_template(template_id: int, db: Session = Depends(get_db)):
dbuser_template = crud.get_user_template(db, template_id)
if not dbuser_template:
raise HTTPException(status_code=404, detail="User Template not found")
return dbuser_template
return dbuser_template

def get_validated_user(
token: str,
db: Session = Depends(get_db)
) -> UserResponse:
sub = get_subscription_payload(token)
if not sub:
raise HTTPException(status_code=204, detail="Invalid subscription token")

dbuser = crud.get_user(db, sub['username'])
if not dbuser or dbuser.created_at > sub['created_at']:
raise HTTPException(status_code=204, detail="User not found or invalid creation date")

if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']:
raise HTTPException(status_code=204, detail="Subscription has been revoked")

return dbuser
173 changes: 52 additions & 121 deletions app/routers/subscription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from datetime import datetime
from datetime import datetime, timedelta
from distutils.version import LooseVersion

from fastapi import Depends, Header, HTTPException, Path, Request, Response, APIRouter
Expand All @@ -9,7 +9,7 @@
from app.models.user import SubscriptionUserResponse, UserResponse
from app.subscription.share import encode_title, generate_subscription
from app.templates import render_template
from app.utils.jwt import get_subscription_payload
from app.dependencies import get_validated_user, validate_dates
from config import (
SUB_PROFILE_TITLE,
SUB_SUPPORT_URL,
Expand All @@ -24,38 +24,29 @@

router = APIRouter(tags=['Subscription'], prefix=f'/{XRAY_SUBSCRIPTION_PATH}')

@router.get("/{token}/")
@router.get("/{token}", include_in_schema=False)
def user_subscription(token: str,
request: Request,
db: Session = Depends(get_db),
user_agent: str = Header(default="")):
"""
Subscription link, V2ray and Clash supported
"""
accept_header = request.headers.get("Accept", "")

def get_subscription_user_info(user: UserResponse) -> dict:
return {
"upload": 0,
"download": user.used_traffic,
"total": user.data_limit,
"expire": user.expire,
}

sub = get_subscription_payload(token)
if not sub:
return Response(status_code=204)

dbuser = crud.get_user(db, sub['username'])
if not dbuser or dbuser.created_at > sub['created_at']:
return Response(status_code=204)
def get_subscription_user_info(user: UserResponse) -> dict:
"""Retrieve user subscription information including upload, download, total data, and expiry."""
return {
"upload": 0,
"download": user.used_traffic,
"total": user.data_limit if user.data_limit is not None else 0,
"expire": user.expire if user.expire is not None else 0,
}

if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']:
return Response(status_code=204)

@router.get("/{token}/")
@router.get("/{token}", include_in_schema=False)
def user_subscription(
request: Request,
db: Session = Depends(get_db),
dbuser: UserResponse = Depends(get_validated_user),
user_agent: str = Header(default="")
):
"""Provides a subscription link based on the user agent (Clash, V2Ray, etc.)."""
user: UserResponse = UserResponse.from_orm(dbuser)
crud.update_user_sub(db, dbuser, user_agent)

accept_header = request.headers.get("Accept", "")
if "text/html" in accept_header:
return HTMLResponse(
render_template(
Expand All @@ -77,8 +68,6 @@ def get_subscription_user_info(user: UserResponse) -> dict:
)
}

crud.update_user_sub(db, dbuser, user_agent)

if re.match('^([Cc]lash-verge|[Cc]lash[-\.]?[Mm]eta|[Ff][Ll][Cc]lash|[Mm]ihomo)', user_agent):
conf = generate_subscription(user=user, config_format="clash-meta", as_base64=False, reverse=False)
return Response(content=conf, media_type="text/yaml", headers=response_headers)
Expand Down Expand Up @@ -106,7 +95,6 @@ def get_subscription_user_info(user: UserResponse) -> dict:

elif (USE_CUSTOM_JSON_DEFAULT or USE_CUSTOM_JSON_FOR_V2RAYNG) and re.match('^v2rayNG/(\d+\.\d+\.\d+)', user_agent):
version_str = re.match('^v2rayNG/(\d+\.\d+\.\d+)', user_agent).group(1)
# i don't know what is wrong with v2rayng and these recent changes
if LooseVersion(version_str) >= LooseVersion("1.8.29"):
conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False)
return Response(content=conf, media_type="application/json", headers=response_headers)
Expand All @@ -131,48 +119,26 @@ def get_subscription_user_info(user: UserResponse) -> dict:


@router.get("/{token}/info", response_model=SubscriptionUserResponse)
def user_subscription_info(token: str,
db: Session = Depends(get_db)):
sub = get_subscription_payload(token)
if not sub:
return Response(status_code=404)

dbuser = crud.get_user(db, sub['username'])
if not dbuser or dbuser.created_at > sub['created_at']:
return Response(status_code=404)

elif dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']:
return Response(status_code=404)

def user_subscription_info(
dbuser: UserResponse = Depends(get_validated_user),
):
"""Retrieves detailed information about the user's subscription."""
return dbuser


@router.get("/{token}/usage")
def user_get_usage(token: str,
start: str = None,
end: str = None,
db: Session = Depends(get_db)):

sub = get_subscription_payload(token)
if not sub:
return Response(status_code=204)

dbuser = crud.get_user(db, sub['username'])
if not dbuser or dbuser.created_at > sub['created_at']:
return Response(status_code=204)

if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']:
return Response(status_code=204)

if start is None:
start_date = datetime.utcfromtimestamp(datetime.utcnow().timestamp() - 30 * 24 * 3600)
else:
start_date = datetime.fromisoformat(start)
def user_get_usage(
dbuser: UserResponse = Depends(get_validated_user),
start: str = None,
end: str = None,
db: Session = Depends(get_db)
):
"""Fetches the usage statistics for the user within a specified date range."""
if not validate_dates(start, end):
raise HTTPException(status_code=400, detail="Invalid date range or format")

if end is None:
end_date = datetime.utcnow()
else:
end_date = datetime.fromisoformat(end)
start_date = start if start else datetime.utcnow() - timedelta(days=30)
end_date = end if end else datetime.utcnow()

usages = crud.get_user_usages(db, dbuser, start_date, end_date)

Expand All @@ -181,36 +147,15 @@ def user_get_usage(token: str,

@router.get("/{token}/{client_type}")
def user_subscription_with_client_type(
token: str,
request: Request,
dbuser: UserResponse = Depends(get_validated_user),
client_type: str = Path(..., regex="sing-box|clash-meta|clash|outline|v2ray|v2ray-json"),
db: Session = Depends(get_db),
user_agent: str = Header(default="")
):
"""
Subscription link, v2ray, clash, sing-box, outline and clash-meta supported
"""

def get_subscription_user_info(user: UserResponse) -> dict:
return {
"upload": 0,
"download": user.used_traffic,
"total": user.data_limit if user.data_limit is not None else 0,
"expire": user.expire if user.expire is not None else 0,
}

sub = get_subscription_payload(token)
if not sub:
return Response(status_code=204)

dbuser = crud.get_user(db, sub['username'])
if not dbuser or dbuser.created_at > sub['created_at']:
return Response(status_code=204)

if dbuser.sub_revoked_at and dbuser.sub_revoked_at > sub['created_at']:
return Response(status_code=204)

"""Provides a subscription link based on the specified client type (e.g., Clash, V2Ray)."""
user: UserResponse = UserResponse.from_orm(dbuser)
crud.update_user_sub(db, dbuser, user_agent)

response_headers = {
"content-disposition": f'attachment; filename="{user.username}"',
Expand All @@ -223,32 +168,18 @@ def get_subscription_user_info(user: UserResponse) -> dict:
for key, val in get_subscription_user_info(user).items()
)
}

crud.update_user_sub(db, dbuser, user_agent)

if client_type == "clash-meta":
conf = generate_subscription(user=user, config_format="clash-meta", as_base64=False, reverse=False)
return Response(content=conf, media_type="text/yaml", headers=response_headers)

elif client_type == "sing-box":
conf = generate_subscription(user=user, config_format="sing-box", as_base64=False, reverse=False)
return Response(content=conf, media_type="application/json", headers=response_headers)

elif client_type == "clash":
conf = generate_subscription(user=user, config_format="clash", as_base64=False, reverse=False)
return Response(content=conf, media_type="text/yaml", headers=response_headers)

elif client_type == "v2ray":
conf = generate_subscription(user=user, config_format="v2ray", as_base64=True, reverse=False)
return Response(content=conf, media_type="text/plain", headers=response_headers)

elif client_type == "outline":
conf = generate_subscription(user=user, config_format="outline", as_base64=False, reverse=False)
return Response(content=conf, media_type="application/json", headers=response_headers)

elif client_type == "v2ray-json":
conf = generate_subscription(user=user, config_format="v2ray-json", as_base64=False, reverse=False)
return Response(content=conf, media_type="application/json", headers=response_headers)

client_config = {
"clash-meta": {"config_format": "clash-meta", "media_type": "text/yaml", "as_base64": False, "reverse": False},
"sing-box": {"config_format": "sing-box", "media_type": "application/json", "as_base64": False, "reverse": False},
"clash": {"config_format": "clash", "media_type": "text/yaml", "as_base64": False, "reverse": False},
"v2ray": {"config_format": "v2ray", "media_type": "text/plain", "as_base64": True, "reverse": False},
"outline": {"config_format": "outline", "media_type": "application/json", "as_base64": False, "reverse": False},
"v2ray-json": {"config_format": "v2ray-json", "media_type": "application/json", "as_base64": False, "reverse": False}
}

if client_type in client_config:
config = client_config[client_type]
conf = generate_subscription(user=user, **config)
return Response(content=conf, media_type=config["media_type"], headers=response_headers)
else:
raise HTTPException(status_code=400, detail="Invalid subscription type")

0 comments on commit 215bf32

Please sign in to comment.