Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
OasisAkari committed Nov 29, 2024
2 parents a26b625 + 475dc78 commit 9c6bc2d
Show file tree
Hide file tree
Showing 24 changed files with 2,791 additions and 285 deletions.
4 changes: 2 additions & 2 deletions bots/aiocqhttp/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def as_display(self, text_only=False):
m = html.unescape(self.session.message.message)
if text_only:
m = re.sub(r'\[CQ:text,qq=(.*?)]', r'\1', m)
m = re.sub(CQCodeHandler.pattern, '', m)
m = CQCodeHandler.pattern.sub('', m)
else:
m = CQCodeHandler.pattern.sub(CQCodeHandler.filter_cq, m)
m = CQCodeHandler.filter_cq(m)
m = re.sub(r'\[CQ:at,qq=(.*?)]', fr'{sender_prefix}|\1', m)
m = re.sub(r'\[CQ:json,data=(.*?)]', r'\1', m).replace("\\/", "/")
m = re.sub(r'\[CQ:text,qq=(.*?)]', r'\1', m)
Expand Down
19 changes: 8 additions & 11 deletions bots/aiocqhttp/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import html
import re
from typing import Any, Dict, Optional, Union

import orjson as json

Expand Down Expand Up @@ -31,21 +32,17 @@ class CQCodeHandler:
pattern = re.compile(r'\[CQ:(\w+),[^\]]*\]')

@staticmethod
def filter_cq(match: str):
def filter_cq(s: str) -> str:
"""
过滤CQ码,返回支持的CQ码。
:param match: 正则匹配对象,包含CQ码的字符串消息。
:param s: 正则匹配对象,包含CQ码的字符串消息。
:return: 如果CQ类型在支持列表中,返回原CQ码;否则返回空字符串。
"""
cq_type = match.group(1)
if cq_type in CQCodeHandler.get_supported:
return match.group(0)
else:
return ''
return CQCodeHandler.pattern.sub(lambda m: m.group(0) if m.group(1) in CQCodeHandler.get_supported else '', s)

@staticmethod
def generate_cq(data: dict):
def generate_cq(data: Dict[str, Any]) -> Optional[str]:
"""
生成CQ码字符串。
Expand All @@ -55,15 +52,15 @@ def generate_cq(data: dict):
if 'type' in data and 'data' in data:
cq_type = data['type']
params = data['data']
param_str = [f"{key}={CQCodeHandler.escape_special_char(value)}"
param_str = [f"{key}={CQCodeHandler.escape_special_char(str(value))}"
for key, value in params.items()]
cq_code = f"[CQ:{cq_type}," + ",".join(param_str) + "]"
return cq_code
else:
return None

@staticmethod
def parse_cq(cq_code: str):
def parse_cq(cq_code: str) -> Optional[Dict[str, Union[str, Dict[str, Any]]]]:
"""
解析CQ码字符串,返回包含类型和参数的字典。
Expand Down Expand Up @@ -96,7 +93,7 @@ def parse_cq(cq_code: str):
return data

@staticmethod
def escape_special_char(s, escape_comma: bool = True):
def escape_special_char(s: str, escape_comma: bool = True) -> str:
"""
转义CQ码中的特殊字符。
Expand Down
2 changes: 1 addition & 1 deletion core/builtins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, target_from, target_id, sender_from=None, sender_id=None):
self.session = Session(message=False, target=target_id, sender=sender_id)
self.parent = Bot.MessageSession(self.target, self.session)
if sender_id:
self.parent.target.sender_info = exports.get("BotDBUtil").SenderInfo(f'{sender_from}|{sender_id}')
self.parent.target.sender_id = exports.get("BotDBUtil").SenderInfo(f'{sender_from}|{sender_id}')


Bot.FetchedSession = FetchedSession
Expand Down
125 changes: 83 additions & 42 deletions core/builtins/message/__init__.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions core/builtins/message/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(
Logger.error(f"Unexpected message type: {elements}")

@property
def is_safe(self):
def is_safe(self) -> bool:
"""
检查消息链是否安全。
"""
Expand Down Expand Up @@ -190,7 +190,7 @@ def unsafeprompt(name, secret, text):
return False
return True

def as_sendable(self, msg: 'MessageSession' = None, embed=True):
def as_sendable(self, msg: 'MessageSession' = None, embed: bool = True) -> list:
"""
将消息链转换为可发送的格式。
"""
Expand Down Expand Up @@ -234,7 +234,7 @@ def as_sendable(self, msg: 'MessageSession' = None, embed=True):
)
return value

def to_list(self, locale="zh_cn", embed=True, msg: 'MessageSession' = None):
def to_list(self, locale: str = "zh_cn", embed: bool = True, msg: 'MessageSession' = None) -> list:
"""
将消息链转换为列表。
"""
Expand Down Expand Up @@ -322,7 +322,7 @@ def __iadd__(self, other):
return self


def match_kecode(text: str) -> List[Union[Plain, Image, Voice, Embed]]:
def match_kecode(text: str) -> List[Union[Plain, Image, Voice]]:
split_all = re.split(r"(\[Ke:.*?])", text)
split_all = [x for x in split_all if x]
elements = []
Expand Down
14 changes: 12 additions & 2 deletions core/dirty_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def parse_data(result: dict, msg: Bot.MessageSession = None, additional_text=Non

@retry(stop=stop_after_attempt(3), wait=wait_fixed(3))
async def check(*text: Union[str, List[str]], msg: Bot.MessageSession = None, additional_text=None) -> List[Dict]:
'''检查字符串是否合规
'''检查字符串。
:param text: 字符串(List/Union)。
:param msg: 消息会话,若指定则本地化返回的消息。
Expand Down Expand Up @@ -179,14 +179,24 @@ async def check(*text: Union[str, List[str]], msg: Bot.MessageSession = None, ad


async def check_bool(*text: Union[str, List[str]]) -> bool:
'''检查字符串是否合规。
:param text: 字符串(List/Union)。
:returns: 字符串是否合规。
'''
chk = await check(*text)
for x in chk:
if not x['status']:
return True
return False


def rickroll(msg: Bot.MessageSession):
def rickroll(msg: Bot.MessageSession) -> str:
'''合规检查失败时输出的Rickroll消息。
:param msg: 消息会话。
:returns: Rickroll消息。
'''
if Config("enable_rickroll", True) and Config("rickroll_msg", cfg_type=str):
return msg.locale.t_str(Config("rickroll_msg", cfg_type=str))
else:
Expand Down
2 changes: 1 addition & 1 deletion core/joke.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from core.utils.http import url_pattern


def joke(text: str):
def joke(text: str) -> str:
current_date = datetime.now().date()
enable_joke = Config('enable_joke', True, cfg_type=bool)

Expand Down
12 changes: 6 additions & 6 deletions core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ class LoggingLogger:
def __init__(self, name):
self.log = logger
self.log.remove()
self.info = None
self.error = None
self.debug = None
self.warning = None
self.exception = None
self.critical = None
self.info = logger.info
self.error = logger.error
self.debug = logger.debug
self.warning = logger.warning
self.exception = logger.exception
self.critical = logger.critical

self.rename(name)

Expand Down
5 changes: 4 additions & 1 deletion core/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from core.constants.path import cache_path


def random_cache_path():
def random_cache_path() -> str:
'''
提供带有随机UUID文件名的缓存路径。
'''
return join(cache_path, str(uuid.uuid4()))


Expand Down
17 changes: 13 additions & 4 deletions core/utils/cooldown.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
from typing import Dict, Union
from typing import Dict

from core.builtins import MessageSession

_cd_lst: Dict[str, Dict[Union[MessageSession, str], float]] = {}
_cd_lst: Dict[str, Dict[MessageSession, float]] = {}


class CoolDown:

def __init__(self, key: str, msg: Union[MessageSession, str], all: bool = False):
def __init__(self, key: str, msg: MessageSession, all: bool = False):
self.key = key
self.msg = msg
self.sender_id = self.msg
Expand All @@ -19,11 +19,17 @@ def __init__(self, key: str, msg: Union[MessageSession, str], all: bool = False)
self.sender_id = self.sender_id.target.sender_id

def add(self):
'''
添加冷却事件。
'''
if self.key not in _cd_lst:
_cd_lst[self.key] = {}
_cd_lst[self.key][self.sender_id] = datetime.datetime.now().timestamp()

def check(self, delay: int):
def check(self, delay: int) -> float:
'''
检查冷却事件剩余冷却时间。
'''
if self.key not in _cd_lst:
return 0
if self.sender_id in _cd_lst[self.key]:
Expand All @@ -35,6 +41,9 @@ def check(self, delay: int):
return 0

def reset(self):
'''
重置冷却事件。
'''
if self.key in _cd_lst:
if self.sender_id in _cd_lst[self.key]:
_cd_lst[self.key].pop(self.sender_id)
Expand Down
19 changes: 17 additions & 2 deletions core/utils/game.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Union
from typing import Any, Optional

from core.logger import Logger
from core.builtins import MessageSession
Expand All @@ -10,7 +10,7 @@


class PlayState:
def __init__(self, game: str, msg: Union[MessageSession, str], all: bool = False):
def __init__(self, game: str, msg: MessageSession, all: bool = False):
self.game = game
self.msg = msg
self.all = all
Expand All @@ -26,6 +26,9 @@ def _get_game_dict(self):
return sender_dict.setdefault(self.game, {'_status': False, '_timestamp': 0.0})

def enable(self) -> None:
'''
开启游戏事件。
'''
game_dict = self._get_game_dict()
game_dict['_status'] = True
game_dict['_timestamp'] = datetime.now().timestamp()
Expand All @@ -35,6 +38,9 @@ def enable(self) -> None:
Logger.info(f'[{self.sender_id}]: Enabled {self.game} at {self.target_id}.')

def disable(self, auto=False) -> None:
'''
关闭游戏事件。
'''
if self.target_id not in playstate_lst:
return
target_dict = playstate_lst[self.target_id]
Expand All @@ -60,6 +66,9 @@ def disable(self, auto=False) -> None:
Logger.info(f'[{self.sender_id}]: Disabled {self.game} at {self.target_id}.')

def update(self, **kwargs) -> None:
'''
更新游戏事件中需要的值。
'''
game_dict = self._get_game_dict()
game_dict.update(kwargs)
if self.all:
Expand All @@ -68,6 +77,9 @@ def update(self, **kwargs) -> None:
Logger.debug(f'[{self.game}]: Updated {str(kwargs)} at {self.sender_id} ({self.target_id}).')

def check(self) -> bool:
'''
检查游戏事件状态,若超过时间则自动关闭。
'''
if self.target_id not in playstate_lst:
return False
target_dict = playstate_lst[self.target_id]
Expand All @@ -83,6 +95,9 @@ def check(self) -> bool:
return status

def get(self, key: str) -> Optional[Any]:
'''
获取游戏事件中需要的值。
'''
if self.target_id not in playstate_lst:
return None
target_dict = playstate_lst[self.target_id]
Expand Down
52 changes: 42 additions & 10 deletions core/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.parse
import uuid
from http.cookies import SimpleCookie
from typing import Union
from typing import Any, Dict, Optional, Union

import aiohttp
import filetype as ft
Expand Down Expand Up @@ -43,9 +43,16 @@ def private_ip_check(url: str):
f'Attempt of requesting private IP addresses is not allowed, requesting {hostname}.')


async def get_url(url: str, status_code: int = False, headers: dict = None, params: dict = None, fmt: str = None, timeout: int = 20,
async def get_url(url: str,
status_code: int = False,
headers: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
fmt: Optional[str] = None,
timeout: int = 20,
attempt: int = 3,
request_private_ip: bool = False, logging_err_resp: bool = True, cookies: dict = None):
request_private_ip: bool = False,
logging_err_resp: bool = True,
cookies: Optional[Dict[str, Any]] = None) -> Optional[str]:
"""利用AioHttp获取指定url的内容。
:param url: 需要获取的url。
Expand Down Expand Up @@ -104,8 +111,16 @@ async def get_():
return await get_()


async def post_url(url: str, data: any = None, status_code: int = False, headers: dict = None, fmt: str = None, timeout: int = 20,
attempt: int = 3, request_private_ip: bool = False, logging_err_resp: bool = True, cookies: dict = None):
async def post_url(url: str,
data: Any = None,
status_code: int = False,
headers: Optional[Dict[str, Any]] = None,
fmt: Optional[str] = None,
timeout: int = 20,
attempt: int = 3,
request_private_ip: bool = False,
logging_err_resp: bool = True,
cookies: Optional[Dict[str, Any]] = None) -> Optional[str]:
'''利用AioHttp发送POST请求。
:param url: 需要发送的url。
Expand Down Expand Up @@ -163,8 +178,16 @@ async def _post():
return await _post()


async def download(url: str, filename: str = None, path: str = None, status_code: int = False, method: str = "GET", post_data: any = None,
headers: dict = None, timeout: int = 20, attempt: int = 3, request_private_ip: bool = False,
async def download(url: str,
filename: Optional[str] = None,
path: Optional[str] = None,
status_code: int = False,
method: str = "GET",
post_data: Any = None,
headers: Optional[Dict[str, Any]] = None,
timeout: int = 20,
attempt: int = 3,
request_private_ip: bool = False,
logging_err_resp: bool = True) -> Union[str, bool]:
'''利用AioHttp下载指定url的内容,并保存到指定目录。
Expand Down Expand Up @@ -218,10 +241,19 @@ async def download_(filename=filename, path=path):
return await download_()


async def dowanload_to_cache(url: str, filename: str = None, status_code: int = False, method: str = "GET", post_data: any = None,
headers: dict = None, timeout: int = 20, attempt: int = 3, request_private_ip: bool = False,
async def dowanload_to_cache(url: str,
filename: Optional[str] = None,
status_code: int = False,
method: str = "GET",
post_data: Any = None,
headers: Optional[Dict[str, Any]] = None,
timeout: int = 20,
attempt: int = 3,
request_private_ip: bool = False,
logging_err_resp: bool = True) -> Union[str, bool]:
'''下载内容到缓存,仅作兼容用。'''
'''
下载内容到缓存目录,仅作兼容用。
'''
await download(url=url, filename=filename, path=cache_path, status_code=status_code, method=method, post_data=post_data,
headers=headers, timeout=timeout, attempt=attempt, request_private_ip=request_private_ip,
logging_err_resp=logging_err_resp)
Expand Down
Loading

0 comments on commit 9c6bc2d

Please sign in to comment.