Skip to content

Commit

Permalink
Merge pull request #1002 from RockChinQ/feat/discord
Browse files Browse the repository at this point in the history
feat: add `discord` adapter
  • Loading branch information
RockChinQ authored Feb 2, 2025
2 parents 12fc76b + 5381e09 commit e5659db
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pkg/core/bootutils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
"argon2": "argon2-cffi",
"jwt": "pyjwt",
"Crypto": "pycryptodome",
"lark_oapi": "lark-oapi"
"lark_oapi": "lark-oapi",
"discord": "discord.py"
}


Expand Down
28 changes: 28 additions & 0 deletions pkg/core/migrations/m024_discord_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("discord-config", 24)
class DiscordConfigMigration(migration.Migration):
"""迁移"""

async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""

for adapter in self.ap.platform_cfg.data['platform-adapters']:
if adapter['adapter'] == 'discord':
return False

return True

async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "discord",
"enable": False,
"client_id": "1234567890",
"token": "XXXXXXXXXX"
})

await self.ap.platform_cfg.dump_config()
2 changes: 1 addition & 1 deletion pkg/core/stages/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params, m018_xai_config, m019_zhipuai_config
from ..migrations import m020_wecom_config, m021_lark_config, m022_lmstudio_config, m023_siliconflow_config
from ..migrations import m020_wecom_config, m021_lark_config, m022_lmstudio_config, m023_siliconflow_config, m024_discord_config


@stage.stage_class("MigrationStage")
Expand Down
2 changes: 1 addition & 1 deletion pkg/platform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, ap: app.Application = None):

async def initialize(self):

from .sources import nakuru, aiocqhttp, qqbotpy, wecom, lark
from .sources import nakuru, aiocqhttp, qqbotpy, wecom, lark, discord

async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):

Expand Down
264 changes: 264 additions & 0 deletions pkg/platform/sources/discord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from __future__ import annotations

import discord

import typing
import asyncio
import traceback
import time
import re
import base64
import uuid
import json
import os
import datetime

import aiohttp

from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image


class DiscordMessageConverter(adapter.MessageConverter):

@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain
) -> typing.Tuple[str, typing.List[discord.File]]:
for ele in message_chain:
if isinstance(ele, platform_message.At):
message_chain.remove(ele)
break

text_string = ""
image_files = []

for ele in message_chain:
if isinstance(ele, platform_message.Image):
image_bytes = None

if ele.base64:
image_bytes = base64.b64decode(ele.base64)
elif ele.url:
async with aiohttp.ClientSession() as session:
async with session.get(ele.url) as response:
image_bytes = await response.read()
elif ele.path:
with open(ele.path, "rb") as f:
image_bytes = f.read()

image_files.append(discord.File(fp=image_bytes, filename=f"{uuid.uuid4()}.png"))
elif isinstance(ele, platform_message.Plain):
text_string += ele.text

return text_string, image_files

@staticmethod
async def target2yiri(
message: discord.Message
) -> platform_message.MessageChain:
lb_msg_list = []

msg_create_time = datetime.datetime.fromtimestamp(
int(message.created_at.timestamp())
)

lb_msg_list.append(
platform_message.Source(id=message.id, time=msg_create_time)
)

element_list = []

def text_element_recur(text_ele: str) -> list[platform_message.MessageComponent]:
if text_ele == "":
return []

# <@1234567890>
# @everyone
# @here
at_pattern = re.compile(r"(@everyone|@here|<@[\d]+>)")
at_matches = at_pattern.findall(text_ele)

if len(at_matches) > 0:
mid_at = at_matches[0]

text_split = text_ele.split(mid_at)

mid_at_component = []

if mid_at == "@everyone" or mid_at == "@here":
mid_at_component.append(platform_message.AtAll())
else:
mid_at_component.append(platform_message.At(target=mid_at[2:-1]))

return text_element_recur(text_split[0]) + \
mid_at_component + \
text_element_recur(text_split[1])
else:
return [platform_message.Plain(text=text_ele)]


element_list.extend(text_element_recur(message.content))

# attachments
for attachment in message.attachments:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(attachment.url) as response:
image_data = await response.read()
image_base64 = base64.b64encode(image_data).decode("utf-8")
image_format = response.headers["Content-Type"]
element_list.append(platform_message.Image(base64=f"data:{image_format};base64,{image_base64}"))

return platform_message.MessageChain(element_list)


class DiscordEventConverter(adapter.EventConverter):

@staticmethod
async def yiri2target(
event: platform_events.Event
) -> discord.Message:
pass

@staticmethod
async def target2yiri(
event: discord.Message
) -> platform_events.Event:
message_chain = await DiscordMessageConverter.target2yiri(event)

if type(event.channel) == discord.DMChannel:
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.author.id,
nickname=event.author.name,
remark=event.channel.id,
),
message_chain=message_chain,
time=event.created_at.timestamp(),
source_platform_object=event,
)
elif type(event.channel) == discord.TextChannel:
return platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=event.author.id,
member_name=event.author.name,
permission=platform_entities.Permission.Member,
group=platform_entities.Group(
id=event.channel.id,
name=event.channel.name,
permission=platform_entities.Permission.Member,
),
special_title="",
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
),
message_chain=message_chain,
time=event.created_at.timestamp(),
source_platform_object=event,
)

@adapter.adapter_class("discord")
class DiscordMessageSourceAdapter(adapter.MessageSourceAdapter):

bot: discord.Client

bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识

config: dict

ap: app.Application

message_converter: DiscordMessageConverter = DiscordMessageConverter()
event_converter: DiscordEventConverter = DiscordEventConverter()

listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
] = {}

def __init__(self, config: dict, ap: app.Application):
self.config = config
self.ap = ap

self.bot_account_id = self.config["client_id"]

adapter_self = self

class MyClient(discord.Client):

async def on_message(self: discord.Client, message: discord.Message):
if message.author.id == self.user.id or message.author.bot:
return

lb_event = await adapter_self.event_converter.target2yiri(message)
await adapter_self.listeners[type(lb_event)](lb_event, adapter_self)

intents = discord.Intents.default()
intents.message_content = True

args = {}

if os.getenv("http_proxy"):
args["proxy"] = os.getenv("http_proxy")

self.bot = MyClient(intents=intents, **args)

async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
pass

async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
msg_to_send, image_files = await self.message_converter.yiri2target(message)
assert isinstance(message_source.source_platform_object, discord.Message)

args = {
"content": msg_to_send,
}

if len(image_files) > 0:
args["files"] = image_files

if quote_origin:
args["reference"] = message_source.source_platform_object

if message.has(platform_message.At):
args["mention_author"] = True

await message_source.source_platform_object.channel.send(**args)

async def is_muted(self, group_id: int) -> bool:
return False

def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
):
self.listeners[event_type] = callback

def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
):
self.listeners.pop(event_type)

async def run_async(self):
async with self.bot:
await self.bot.start(self.config["token"], reconnect=True)

async def kill(self) -> bool:
await self.bot.close()
return True
5 changes: 5 additions & 0 deletions pkg/platform/types/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class MessageEvent(Event):
message_chain: platform_message.MessageChain
"""消息内容。"""

source_platform_object: typing.Optional[typing.Any] = None
"""原消息平台对象。
供消息平台适配器开发者使用,如果回复用户时需要使用原消息事件对象的信息,
那么可以将其存到这个字段以供之后取出使用。"""


class FriendMessage(MessageEvent):
"""好友消息。
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ argon2-cffi
pyjwt
pycryptodome
lark-oapi
discord.py

# indirect
taskgroup==0.0.0a4
6 changes: 6 additions & 0 deletions templates/platform.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
"app_id": "cli_abcdefgh",
"app_secret": "XXXXXXXXXX",
"bot_name": "LangBot"
},
{
"adapter": "discord",
"enable": true,
"client_id": "1234567890",
"token": "XXXXXXXXXX"
}
],
"track-function-calls": true,
Expand Down
31 changes: 31 additions & 0 deletions templates/schema/platform.json
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,37 @@
"description": "飞书的bot_name"
}
}
},
{
"title": "Discord 适配器",
"description": "用于接入 Discord",
"properties": {
"adapter": {
"type": "string",
"const": "discord"
},
"enable": {
"type": "boolean",
"default": false,
"description": "是否启用此适配器",
"layout": {
"comp": "switch",
"props": {
"color": "primary"
}
}
},
"client_id": {
"type": "string",
"default": "",
"description": "Discord 的 client_id"
},
"token": {
"type": "string",
"default": "",
"description": "Discord 的 token"
}
}
}
]
}
Expand Down

0 comments on commit e5659db

Please sign in to comment.