Skip to content

Commit

Permalink
Merge pull request revoltchat#9 from psychopurp/dev
Browse files Browse the repository at this point in the history
feat(quark): support mongo:4.2.0
  • Loading branch information
psychopurp authored Sep 4, 2023
2 parents e465f00 + ee96899 commit 6c9a65e
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 55 deletions.
6 changes: 3 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ AUTUMN_PUBLIC_URL=http://local.revolt.chat:3000
# URL to where January is publicly available
JANUARY_PUBLIC_URL=http://local.revolt.chat:7000

# URL to bot server
BOT_SERVER_PUBLIC_URL=http://bot.server.com
# URL for prompt server
BOT_SERVER_PUBLIC_URL=http://prompt.server.com

# GPT-4 GPT-3
# GPT-4,GPT-3 Bots to add when user joined
OFFICIAL_MODEL_BOTS=01H7A1NB3P10J15DFEPQ7RN67J,01H6ZWPCCKQ4J46D088HBY5ZP4
# AIcoder Autoinstall
OFFICIAL_CUSTOM_BOTS=01H8DF1Y5CF0BA28D2TKBSSFGV
Expand Down
24 changes: 17 additions & 7 deletions crates/core/database/src/drivers/mongodb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,23 @@ impl MongoDb {
}
}

let query = doc! {
"$unset": unset,
"$set": if let Some(prefix) = &prefix {
to_document(&prefix_keys(&partial, prefix))
} else {
to_document(&partial)
}?
let query = if unset.is_empty() {
doc! {
"$set": if let Some(prefix) = &prefix {
to_document(&prefix_keys(&partial, prefix))
} else {
to_document(&partial)
}?
}
} else {
doc! {
"$unset": unset,
"$set": if let Some(prefix) = &prefix {
to_document(&prefix_keys(&partial, prefix))
} else {
to_document(&partial)
}?
}
};

self.col::<Document>(collection)
Expand Down
5 changes: 5 additions & 0 deletions crates/core/database/src/models/bots/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ auto_derived_partial!(
/// Bot server invite code
#[serde(skip_serializing_if = "Option::is_none")]
pub server_invite: Option<String>,

/// Bot's default server
#[serde(skip_serializing_if = "Option::is_none")]
pub default_server: Option<String>,
},
"PartialBot"
);
Expand Down Expand Up @@ -87,6 +91,7 @@ impl Default for Bot {
flags: Default::default(),
bot_type: Default::default(),
server_invite: Default::default(),
default_server: Default::default(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/core/database/src/util/bridge/v0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl From<crate::Bot> for Bot {
flags: value.flags.unwrap_or_default() as u32,
bot_type: value.bot_type.map(|x| x.into()),
server_invite: value.server_invite,
default_server: value.default_server,
}
}
}
Expand Down
33 changes: 31 additions & 2 deletions crates/core/models/src/v0/bots.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use super::users::BotModel;
use super::User;

use validator::Validate;

auto_derived!(
Expand Down Expand Up @@ -63,6 +62,10 @@ auto_derived!(
/// Bot server invite code
#[serde(skip_serializing_if = "Option::is_none")]
pub server_invite: Option<String>,

/// Bot's default server
#[serde(skip_serializing_if = "Option::is_none")]
pub default_server: Option<String>,
}

/// Optional fields on bot object
Expand Down Expand Up @@ -158,7 +161,6 @@ auto_derived_with_no_eq!(
}

/// Bot Details
#[derive(Default)]
#[cfg_attr(feature = "validator", derive(validator::Validate))]
pub struct DataCreateBot {
/// Bot username
Expand All @@ -184,3 +186,30 @@ auto_derived_with_no_eq!(
pub users: Vec<User>,
}
);

#[cfg(test)]
#[cfg(feature = "validator")]
mod tests {
use crate::v0::{BotModel, BotType, DataCreateBot, PromptTemplate};
use validator::Validate;

#[test]
fn test_validate() {
let mut bot = DataCreateBot {
name: "mybot".into(),
bot_type: Some(BotType::PromptBot),
model: Some(BotModel {
model_name: "gpt4".into(),
prompts: PromptTemplate {
system_prompt: "".into(),
},
temperature: 2.0,
}),
};

assert!(bot.validate().map_err(|e| println!("{e}")).is_err());

bot.model.as_mut().unwrap().temperature = 0.5;
assert!(bot.validate().map_err(|e| println!("{e}")).is_ok());
}
}
5 changes: 4 additions & 1 deletion crates/delta/src/routes/bots/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ pub async fn create_bot(
let (server, channel) = create_default_channel_for_bot(db, info.name.clone(), &owner).await?;

let mut invite_code: Option<String> = None;
let mut default_server: Option<String> = None;

if let Invite::Server { code, .. } =
if let Invite::Server { code, server, .. } =
Invite::create_channel_invite(db, user.id.clone(), &channel).await?
{
invite_code = Some(code);
default_server = Some(server)
}

let bot = Bot::create(
Expand All @@ -80,6 +82,7 @@ pub async fn create_bot(
PartialBot {
bot_type: Some(bot_type.clone()),
server_invite: invite_code,
default_server,
..Default::default()
},
)
Expand Down
58 changes: 43 additions & 15 deletions crates/delta/src/routes/channels/group_create.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use std::{collections::HashSet, iter::FromIterator};

use revolt_models::v0;
use revolt_quark::{
get_relationship,
models::{user::RelationshipStatus, Channel, User},
variables::delta::MAX_GROUP_SIZE,
Db, Error, Result,
models::user::{RelationshipStatus, User},
variables::delta::{MAX_GROUP_SIZE, OFFICIAL_MODEL_BOTS},
};

use rocket::serde::json::Json;
use revolt_database::{Channel, Database};
use revolt_result::{create_error, Result};
use rocket::{serde::json::Json, State};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
use validator::Validate;
Expand Down Expand Up @@ -36,38 +38,44 @@ pub struct DataCreateGroup {
/// Create a new group channel.
#[openapi(tag = "Groups")]
#[post("/create", data = "<info>")]
pub async fn req(db: &Db, user: User, info: Json<DataCreateGroup>) -> Result<Json<Channel>> {
pub async fn req(
db: &State<Database>,
user: User,
info: Json<DataCreateGroup>,
) -> Result<Json<v0::Channel>> {
if user.bot.is_some() {
return Err(Error::IsBot);
return Err(create_error!(IsBot));
}

let info = info.into_inner();
info.validate()
.map_err(|error| Error::FailedValidation { error })?;
info.validate().map_err(|error| {
create_error!(FailedValidation {
error: error.to_string()
})
})?;

let mut set: HashSet<String> = HashSet::from_iter(info.users.into_iter());
set.insert(user.id.clone());

if set.len() > *MAX_GROUP_SIZE {
return Err(Error::GroupTooLarge {
return Err(create_error!(GroupTooLarge {
max: *MAX_GROUP_SIZE,
});
}));
}

for target in &set {
match get_relationship(&user, target) {
RelationshipStatus::Friend | RelationshipStatus::User => {}
_ => {
return Err(Error::NotFriends);
return Err(create_error!(NotFriends));
}
}
}

let group = Channel::Group {
let mut group = Channel::Group {
id: Ulid::new().to_string(),

name: info.name,
owner: user.id,
owner: user.id.clone(),
description: info.description,
recipients: set.into_iter().collect::<Vec<String>>(),

Expand All @@ -80,5 +88,25 @@ pub async fn req(db: &Db, user: User, info: Json<DataCreateGroup>) -> Result<Jso
};

group.create(db).await?;
Ok(Json(group))

add_official_prompt_bots(db, user.id.clone(), &mut group).await?;

Ok(Json(group.into()))
}

/// add official prompts bot for any new created group
async fn add_official_prompt_bots(
db: &Database,
user_id: String,
group: &mut Channel,
) -> Result<()> {
if (*OFFICIAL_MODEL_BOTS).is_empty() {
return Ok(());
}

for bot in db.fetch_users(OFFICIAL_MODEL_BOTS.as_slice()).await? {
group.add_user_to_group(&db.clone(), &bot, &user_id).await?;
}

Ok(())
}
64 changes: 50 additions & 14 deletions crates/delta/src/routes/onboard/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ pub async fn req(
error: error.to_string()
})
})?;

let new_user = User::create(db, data.username, session.user_id, None).await?;

prepare_on_board_data(db, new_user.id.clone()).await?;
Expand Down Expand Up @@ -75,17 +74,30 @@ async fn prepare_on_board_data(db: &Database, user_id: String) -> Result<()> {
};

group.create(db).await?;
let bot_users = db.fetch_users(OFFICIAL_MODEL_BOTS.as_slice()).await?;
for bot in bot_users {
group.add_user_to_group(&db.clone(), &bot, &user_id).await?;

for bot in db.fetch_users(OFFICIAL_MODEL_BOTS.as_slice()).await? {
group.add_user_to_group(db, &bot, &user_id).await?;

Channel::DirectMessage {
id: Ulid::new().to_string(),
active: false,
recipients: vec![bot.id, user_id.clone()],
last_message_id: None,
}
.create(db)
.await?;
}

Ok(())
}

#[cfg(test)]
mod tests {
use std::collections::HashSet;

use crate::{rocket, routes::onboard::complete::DataOnboard, util::test::TestHarness};
use revolt_database::Channel;
use revolt_models::v0;
use revolt_quark::variables::delta::OFFICIAL_MODEL_BOTS;
use rocket::http::{ContentType, Header, Status};

Expand All @@ -94,14 +106,6 @@ mod tests {
let harness = TestHarness::new().await;
let (_, session) = harness.new_account_session().await;

let mut users = vec![];
users.extend((*OFFICIAL_MODEL_BOTS).clone());
for user_bot in users {
// let _ = harness.db.delete_user(&user_bot).await;
// let _ = User::create(&harness.db, user_bot.clone(), Some(user_bot.clone()), None).await;
println!("{user_bot}");
}

let response = harness
.client
.post("/onboard/complete")
Expand All @@ -115,8 +119,40 @@ mod tests {
)
.dispatch()
.await;

assert_eq!(response.status(), Status::Ok);
let status = response.status();
// println!("{:}", response.into_string().await.unwrap());
assert_eq!(status, Status::Ok);

let user = response.into_json::<v0::User>().await.unwrap();
let channels = harness.db.find_direct_messages(&user.id).await.unwrap();

assert_eq!(channels.len(), 3);

let mut match_cnt = 0;

for channel in channels.into_iter() {
match channel {
Channel::Group {
owner, recipients, ..
} => {
assert_eq!(owner, user.id);
let set: HashSet<String> = recipients.into_iter().collect();
let mut expect = HashSet::new();
expect.insert(user.id.clone());
for id in OFFICIAL_MODEL_BOTS.as_slice() {
expect.insert(id.clone());
}
assert_eq!(set, expect);
match_cnt += 1;
}

Channel::DirectMessage { .. } => {
match_cnt += 1;
}
_ => panic!("error"),
}
}

assert_eq!(3, match_cnt);
}
}
Loading

0 comments on commit 6c9a65e

Please sign in to comment.