Skip to content

Commit

Permalink
Merge pull request revoltchat#8 from psychopurp/dev
Browse files Browse the repository at this point in the history
feat(core/database/delta): onboard preparation &  request bot-server for creating prompt bot
  • Loading branch information
psychopurp authored Aug 31, 2023
2 parents 84bddbd + 849161e commit e465f00
Show file tree
Hide file tree
Showing 14 changed files with 391 additions and 26 deletions.
8 changes: 8 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ AUTUMN_PUBLIC_URL=http://local.revolt.chat:3000
# URL to where January is publicly available
JANUARY_PUBLIC_URL=http://local.revolt.chat:7000

# URL to bot server
BOT_SERVER_PUBLIC_URL=http://bot.server.com

# GPT-4 GPT-3
OFFICIAL_MODEL_BOTS=01H7A1NB3P10J15DFEPQ7RN67J,01H6ZWPCCKQ4J46D088HBY5ZP4
# AIcoder Autoinstall
OFFICIAL_CUSTOM_BOTS=01H8DF1Y5CF0BA28D2TKBSSFGV

# URL to where Vortex is publicly available
# VOSO_PUBLIC_URL=https://voso.revolt.chat

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ pub async fn run_migrations(db: &MongoDb, revision: i32) -> i32 {
doc! {},
doc! {
"$set": {
"joined_at": DateTime::now().to_rfc3339_string()
"joined_at": DateTime::now().try_to_rfc3339_string().unwrap()
}
},
None,
Expand Down
5 changes: 5 additions & 0 deletions crates/core/database/src/models/bots/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ auto_derived_partial!(
/// Enum of bot flags
#[serde(skip_serializing_if = "Option::is_none")]
pub flags: Option<i32>,

/// Bot server invite code
#[serde(skip_serializing_if = "Option::is_none")]
pub server_invite: Option<String>,
},
"PartialBot"
);
Expand Down Expand Up @@ -82,6 +86,7 @@ impl Default for Bot {
privacy_policy_url: Default::default(),
flags: Default::default(),
bot_type: Default::default(),
server_invite: Default::default(),
}
}
}
Expand Down
21 changes: 13 additions & 8 deletions crates/core/database/src/models/channel_invites/model.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use revolt_result::{create_error, Result};

use crate::Database;
use crate::{Channel, Database};

/* static ALPHABET: [char; 54] = [
static ALPHABET: [char; 54] = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J',
'K', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
'g', 'h', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'v', 'w', 'x', 'y', 'z',
]; */
];

auto_derived!(
/// Invite
#[serde(tag = "type")]
pub enum Invite {
/// Invite to a specific server channel
Server {
Expand Down Expand Up @@ -56,28 +57,32 @@ impl Invite {
}

/// Create a new invite from given information
/*pub async fn create_channel_invite(db: &Database, creator_id: String, target: &Channel) -> Result<Invite> {
pub async fn create_channel_invite(
db: &Database,
creator_id: String,
target: &Channel,
) -> Result<Invite> {
let code = nanoid::nanoid!(8, &ALPHABET);
let invite = match &target {
Channel::Group { id, .. } => Ok(Invite::Group {
code,
creator: creator.id.clone(),
creator: creator_id.clone(),
channel: id.clone(),
}),
Channel::TextChannel { id, server, .. } | Channel::VoiceChannel { id, server, .. } => {
Ok(Invite::Server {
code,
creator: creator.id.clone(),
creator: creator_id.clone(),
server: server.clone(),
channel: id.clone(),
})
}
_ => Err(Error::InvalidOperation),
_ => Err(create_error!(InvalidOperation)),
}?;

db.insert_invite(&invite).await?;
Ok(invite)
}*/
}

/// Resolve an invite by its ID or by a public server ID
pub async fn find(db: &Database, code: &str) -> Result<Invite> {
Expand Down
23 changes: 23 additions & 0 deletions crates/core/database/src/models/servers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,29 @@ auto_derived!(
}
);

#[allow(clippy::derivable_impls)]
impl Default for Server {
fn default() -> Self {
Self {
name: Default::default(),
description: Default::default(),
channels: Default::default(),
categories: Default::default(),
system_messages: Default::default(),
roles: Default::default(),
default_permissions: Default::default(),
icon: Default::default(),
banner: Default::default(),
nsfw: Default::default(),
id: Default::default(),
owner: Default::default(),
flags: Default::default(),
analytics: Default::default(),
discoverable: Default::default(),
}
}
}

#[allow(clippy::disallowed_methods)]
impl Server {
/// Create a server
Expand Down
1 change: 1 addition & 0 deletions crates/core/database/src/util/bridge/v0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl From<crate::Bot> for Bot {
privacy_policy_url: value.privacy_policy_url,
flags: value.flags.unwrap_or_default() as u32,
bot_type: value.bot_type.map(|x| x.into()),
server_invite: value.server_invite,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/core/database/src/util/idempotency.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::num::NonZeroUsize;

use revolt_result::{create_error, Error, Result};
use revolt_result::{create_error, Result};

use async_std::sync::Mutex;
use once_cell::sync::Lazy;
Expand Down Expand Up @@ -85,7 +85,7 @@ use rocket::{
#[cfg(feature = "rocket-impl")]
#[async_trait]
impl<'r> FromRequest<'r> for IdempotencyKey {
type Error = Error;
type Error = revolt_result::Error;

async fn from_request(request: &'r rocket::Request<'_>) -> Outcome<Self, Self::Error> {
if let Some(key) = request
Expand Down
4 changes: 4 additions & 0 deletions crates/core/models/src/v0/bots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ auto_derived!(
serde(skip_serializing_if = "crate::if_zero_u32", default)
)]
pub flags: u32,

/// Bot server invite code
#[serde(skip_serializing_if = "Option::is_none")]
pub server_invite: Option<String>,
}

/// Optional fields on bot object
Expand Down
1 change: 1 addition & 0 deletions crates/core/models/src/v0/channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::HashMap;

auto_derived!(
/// Channel
#[cfg_attr(feature = "serde", serde(tag = "channel_type"))]
pub enum Channel {
/// Personal "Saved Notes" channel which allows users to save messages
SavedMessages {
Expand Down
114 changes: 110 additions & 4 deletions crates/delta/src/routes/bots/create.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
use revolt_database::{Bot, BotType, Database, PartialBot, User};
use revolt_database::{Bot, BotType, Channel, Database, Invite, Member, PartialBot, Server, User};
use revolt_models::v0;
use revolt_permissions::DEFAULT_PERMISSION_SERVER;
use revolt_quark::variables::delta::BOT_SERVER_PUBLIC_URL;
use revolt_result::{create_error, Result};
use rocket::serde::json::Json;
use rocket::State;
use std::collections::HashMap;
use ulid::Ulid;
use validator::Validate;

#[derive(Debug, serde::Serialize)]
struct CreatePromptBotReq {
user_id: String,
user_name: String,
bot_id: String,
bot_name: String,
bot_token: String,
model_name: String,
prompt_template: String,
temperature: f32,
}

/// # Create Bot
///
/// Create a new Revolt bot.
Expand Down Expand Up @@ -45,21 +61,111 @@ pub async fn create_bot(
}
}

owner.bot = Some(bot_information.into());
owner.bot = Some(bot_information.clone().into());

let (server, channel) = create_default_channel_for_bot(db, info.name.clone(), &owner).await?;

let mut invite_code: Option<String> = None;

if let Invite::Server { code, .. } =
Invite::create_channel_invite(db, user.id.clone(), &channel).await?
{
invite_code = Some(code);
}

let bot = Bot::create(
db,
info.name,
info.name.clone(),
&owner,
PartialBot {
bot_type: Some(bot_type),
bot_type: Some(bot_type.clone()),
server_invite: invite_code,
..Default::default()
},
)
.await?;

let bot_user = db.fetch_user(&bot.id).await?;

Member::create(db, &server, &user).await?;
Member::create(db, &server, &bot_user).await?;

if bot_type == BotType::PromptBot && !(*BOT_SERVER_PUBLIC_URL).is_empty() {
let _ = create_bot_at_bot_server(&bot, &bot_user, &user).await;
}

Ok(Json(bot.into()))
}

async fn create_default_channel_for_bot(
db: &Database,
bot_name: String,
user: &User,
) -> Result<(Server, Channel)> {
let channel_id = Ulid::new().to_string();
let server_id = Ulid::new().to_string();

let channel = Channel::TextChannel {
id: channel_id.clone(),
server: server_id.clone(),
name: "默认频道".into(),
description: None,
icon: None,
last_message_id: None,
default_permissions: None,
role_permissions: HashMap::new(),
nsfw: false,
};

channel.create(db).await?;

let server = Server {
id: server_id.clone(),
owner: user.id.clone(),
name: bot_name + "的社区",
description: None,
channels: vec![channel_id],
nsfw: false,
default_permissions: *DEFAULT_PERMISSION_SERVER as i64,
..Default::default()
};

server.create(db).await?;
Ok((server, channel))
}

async fn create_bot_at_bot_server(bot: &Bot, bot_user: &User, bot_owner: &User) -> Result<()> {
let model = bot_user.bot.as_ref().unwrap().model.as_ref().unwrap();

let data = CreatePromptBotReq {
user_id: bot_owner.id.clone(),
user_name: bot_owner.username.clone(),
bot_id: bot.id.clone(),
bot_name: bot_user.username.clone(),
bot_token: bot.token.clone(),
model_name: model.model_name.clone(),
prompt_template: model.prompts.system_prompt.clone(),
temperature: model.temperature,
};

let host = BOT_SERVER_PUBLIC_URL.to_string();
let url = format!("{host}/api/rest/v1/bot/create");
let client = reqwest::Client::new();
let response = client
.post(url.clone())
.json(&data)
.send()
.await
.map_err(|_| create_error!(InternalError))?
.text()
.await
.map_err(|_| create_error!(InternalError))?;

let data_json = serde_json::to_string(&data).map_err(|_| create_error!(InternalError))?;
info!("bot-server:\nurl:{url}\ndata:{data_json}\nresponse:{response}");
Ok(())
}

#[cfg(test)]
mod test {
use crate::{rocket, util::test::TestHarness};
Expand Down
Loading

0 comments on commit e465f00

Please sign in to comment.