diff --git a/Cargo.toml b/Cargo.toml index 9dc995b..1c76203 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gorse_rs" -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "Rust SDK for gorse recommender system" readme = "README.md" @@ -14,8 +14,10 @@ categories = ["algorithms", "science"] reqwest = { version = "0.12", features = ["blocking", "json"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +serde_url_params = "0.2" [dev-dependencies] -chrono = "0.4.23" -redis = "0.25.1" -tokio = { version = "1.22.0", features = ["macros"] } +chrono = "0.4.38" +redis = "0.25.4" +serial_test = "3.1.1" +tokio = { version = "1.38.0", features = ["macros"] } diff --git a/src/async.rs b/src/async.rs new file mode 100644 index 0000000..462e99a --- /dev/null +++ b/src/async.rs @@ -0,0 +1,960 @@ +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use crate::{ + query::{CursorQuery, OffsetQuery, UserIdQuery, WriteBackQuery}, + Error, Feedback, Feedbacks, Health, Item, Items, Method, Result, RowAffected, Score, + StatusCode, User, Users, +}; + +#[derive(Debug, Clone)] +pub struct Gorse { + entry_point: String, + api_key: String, + client: Client, +} + +impl Gorse { + pub fn new(entry_point: impl Into, api_key: impl Into) -> Self { + Self { + entry_point: entry_point.into(), + api_key: api_key.into(), + client: Client::new(), + } + } + + pub async fn insert_user(&self, user: &User) -> Result { + return self + .request(Method::POST, format!("{}api/user", self.entry_point), user) + .await; + } + + pub async fn get_user(&self, user_id: &str) -> Result { + return self + .request::<(), User>( + Method::GET, + format!("{}api/user/{}", self.entry_point, user_id), + &(), + ) + .await; + } + + pub async fn delete_user(&self, user_id: &str) -> Result { + return self + .request::<(), RowAffected>( + Method::DELETE, + format!("{}api/user/{}", self.entry_point, user_id), + &(), + ) + .await; + } + + pub async fn update_user(&self, user: &User) -> Result { + return self + .request( + Method::PATCH, + format!("{}api/user/{}", self.entry_point, user.user_id), + user, + ) + .await; + } + + pub async fn list_users(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Users>( + Method::GET, + format!( + "{}api/users?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .await + .map(|users| users.users); + } + + pub async fn insert_users(&self, users: &Vec) -> Result { + return self + .request( + Method::POST, + format!("{}api/users", self.entry_point), + users, + ) + .await; + } + + pub async fn insert_item(&self, item: &Item) -> Result { + return self + .request(Method::POST, format!("{}api/item", self.entry_point), item) + .await; + } + + pub async fn get_item(&self, item_id: &str) -> Result { + return self + .request::<(), Item>( + Method::GET, + format!("{}api/item/{}", self.entry_point, item_id), + &(), + ) + .await; + } + + pub async fn delete_item(&self, item_id: &str) -> Result { + return self + .request::<(), RowAffected>( + Method::DELETE, + format!("{}api/item/{}", self.entry_point, item_id), + &(), + ) + .await; + } + + pub async fn update_item(&self, item: &Item) -> Result { + return self + .request( + Method::PATCH, + format!("{}api/item/{}", self.entry_point, item.item_id), + item, + ) + .await; + } + + pub async fn add_item_to_category(&self, item_id: &str, category: &str) -> Result { + return self + .request( + Method::PUT, + format!( + "{}api/item/{}/category/{}", + self.entry_point, item_id, category + ), + &(), + ) + .await; + } + + pub async fn delete_item_to_category( + &self, + item_id: &str, + category: &str, + ) -> Result { + return self + .request( + Method::DELETE, + format!( + "{}api/item/{}/category/{}", + self.entry_point, item_id, category + ), + &(), + ) + .await; + } + + pub async fn list_items(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Items>( + Method::GET, + format!( + "{}api/items?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .await + .map(|items| items.items); + } + + pub async fn insert_items(&self, items: &Vec) -> Result { + return self + .request( + Method::POST, + format!("{}api/items", self.entry_point), + items, + ) + .await; + } + + pub async fn list_feedback(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Feedbacks>( + Method::GET, + format!( + "{}api/feedback?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .await + .map(|feedbacks| feedbacks.feedbacks); + } + + pub async fn overwrite_feedback(&self, feedback: &Vec) -> Result { + return self + .request( + Method::PUT, + format!("{}api/feedback", self.entry_point), + feedback, + ) + .await; + } + + pub async fn insert_feedback(&self, feedback: &Vec) -> Result { + return self + .request( + Method::POST, + format!("{}api/feedback", self.entry_point), + feedback, + ) + .await; + } + + pub async fn list_feedback_by_type( + &self, + feedback_type: &str, + query: &CursorQuery, + ) -> Result> { + return self + .request::<(), Feedbacks>( + Method::GET, + format!( + "{}api/feedback/{}?{}", + self.entry_point, + feedback_type, + serde_url_params::to_string(query).unwrap() + ), + &(), + ) + .await + .map(|feedbacks| feedbacks.feedbacks); + } + + pub async fn get_feedback( + &self, + feedback_type: &str, + user_id: &str, + item_id: &str, + ) -> Result { + return self + .request::<(), Feedback>( + Method::GET, + format!( + "{}api/feedback/{}/{}/{}", + self.entry_point, feedback_type, user_id, item_id, + ), + &(), + ) + .await; + } + + pub async fn delete_feedback( + &self, + feedback_type: &str, + user_id: &str, + item_id: &str, + ) -> Result { + return self + .request::<(), RowAffected>( + Method::DELETE, + format!( + "{}api/feedback/{}/{}/{}", + self.entry_point, feedback_type, user_id, item_id, + ), + &(), + ) + .await; + } + + pub async fn list_feedback_from_user_by_item( + &self, + user_id: &str, + item_id: &str, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!("{}api/feedback/{}/{}", self.entry_point, user_id, item_id), + &(), + ) + .await; + } + + pub async fn delete_feedback_from_user_by_item( + &self, + user_id: &str, + item_id: &str, + ) -> Result { + return self + .request::<(), RowAffected>( + Method::DELETE, + format!("{}api/feedback/{}/{}", self.entry_point, user_id, item_id), + &(), + ) + .await; + } + + pub async fn list_feedback_by_item(&self, item_id: &str) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!("{}api/item/{}/feedback", self.entry_point, item_id), + &(), + ) + .await; + } + + pub async fn list_feedback_by_item_and_type( + &self, + item_id: &str, + feedback_type: &str, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/feedback/{}", + self.entry_point, item_id, feedback_type + ), + &(), + ) + .await; + } + + pub async fn list_feedback_from_user(&self, user_id: &str) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!("{}api/user/{}/feedback", self.entry_point, user_id), + &(), + ) + .await; + } + + pub async fn list_feedback_from_user_by_type( + &self, + user_id: &str, + feedback_type: &str, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/user/{}/feedback/{}", + self.entry_point, user_id, feedback_type + ), + &(), + ) + .await; + } + + pub async fn get_item_neighbors( + &self, + item_id: &str, + query: &OffsetQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/neighbors?{}", + self.entry_point, + item_id, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_item_neighbors_by_category( + &self, + item_id: &str, + category: &str, + query: &OffsetQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/neighbors/{}?{}", + self.entry_point, + item_id, + category, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_latest(&self, query: &UserIdQuery) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/latest?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_latest_by_category( + &self, + category: &str, + query: &UserIdQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/latest/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_popular(&self, query: &UserIdQuery) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/popular?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_popular_by_category( + &self, + category: &str, + query: &UserIdQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/popular/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_recommend( + &self, + user_id: &str, + query: &WriteBackQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/recommend/{}?{}", + self.entry_point, + user_id, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_recommend_by_category( + &self, + user_id: &str, + category: &str, + query: &WriteBackQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/recommend/{}/{}?{}", + self.entry_point, + user_id, + category, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + pub async fn get_recommend_session( + &self, + feedbacks: &Vec, + query: &OffsetQuery, + ) -> Result> { + return self + .request::, Option>>( + Method::POST, + format!( + "{}api/session/recommend?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + feedbacks, + ) + .await + .map(|scores| scores.unwrap_or_default()); + } + + pub async fn get_recommend_session_by_category( + &self, + feedbacks: &Vec, + category: &str, + query: &OffsetQuery, + ) -> Result> { + return self + .request::, Option>>( + Method::POST, + format!( + "{}api/session/recommend/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + feedbacks, + ) + .await + .map(|scores| scores.unwrap_or_default()); + } + + pub async fn get_user_neighbors( + &self, + user_id: &str, + query: &OffsetQuery, + ) -> Result> { + return self + .request::<(), Vec>( + Method::GET, + format!( + "{}api/user/{}/neighbors?{}", + self.entry_point, + user_id, + serde_url_params::to_string(query)? + ), + &(), + ) + .await; + } + + async fn request Deserialize<'a>>( + &self, + method: Method, + url: String, + body: &BodyType, + ) -> Result { + let response = self + .client + .request(method, url) + .header("X-API-Key", self.api_key.as_str()) + .header("Content-Type", "application/json") + .json(body) + .send() + .await?; + return if response.status() == StatusCode::OK { + let r: RetType = serde_json::from_str(response.text().await?.as_str())?; + Ok(r) + } else { + Err(Box::new(Error { + status_code: response.status(), + message: response.text().await?, + })) + }; + } + + pub async fn is_live(&self) -> Result { + return self + .request( + Method::GET, + format!("{}api/health/live", self.entry_point), + &(), + ) + .await; + } + + pub async fn is_ready(&self) -> Result { + return self + .request( + Method::GET, + format!("{}api/health/ready", self.entry_point), + &(), + ) + .await; + } +} + +#[cfg(test)] +mod tests { + use redis::Commands; + use serial_test::serial; + + use super::*; + + const ENTRY_POINT: &str = "http://127.0.0.1:8088/"; + const API_KEY: &str = "zhenghaoz"; + + #[tokio::test] + #[serial] + async fn test_users() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let mut user = User::new("1", vec!["a", "b", "c"]); + // Insert a user. + let rows_affected = client.insert_user(&user).await?; + assert_eq!(rows_affected.row_affected, 1); + // Get this user. + let return_user = client.get_user("1").await?; + assert_eq!(return_user, user); + // Update this user. + user.labels = vec!["e".into(), "f".into(), "g".into()]; + let rows_affected = client.update_user(&user).await?; + assert_eq!(rows_affected.row_affected, 1); + // Get this user. + let return_user = client.get_user("1").await?; + assert_eq!(return_user, user); + // Delete this user. + let rows_affected = client.delete_user("1").await?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_user("1").await; + assert!(response.is_err()); + // Insert a users. + let users = vec![user, User::new("12", vec!["a", "b", "c"])]; + let rows_affected = client.insert_users(&users).await?; + assert_eq!(rows_affected.row_affected, 2); + // Get this users. + let return_users = client.list_users(&CursorQuery::new()).await?; + assert!(!return_users.is_empty()); + // Delete this users. + let rows_affected = client.delete_user("1").await?; + assert_eq!(rows_affected.row_affected, 1); + let rows_affected = client.delete_user("12").await?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_items() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let category = "test".to_string(); + let item = Item::new( + "1", + vec!["a", "b", "c"], + vec!["d", "e"], + "2022-11-20T13:55:27Z", + ) + .comment("comment"); + let item_with_category = Item::new( + "1", + vec!["a", "b", "c"], + vec!["d", "e", &category], + "2022-11-20T13:55:27Z", + ) + .comment("comment"); + // Insert an item. + let rows_affected = client.insert_item(&item).await?; + assert_eq!(rows_affected.row_affected, 1); + // Add category to item. + let rows_affected = client + .add_item_to_category(&item.item_id, &category) + .await?; + assert_eq!(rows_affected.row_affected, 1); + // Get this item. + let return_item = client.get_item("1").await?; + assert_eq!(return_item, item_with_category); + // Delete category to item. + let rows_affected = client + .delete_item_to_category(&item.item_id, &category) + .await?; + assert_eq!(rows_affected.row_affected, 1); + // Get this item. + let return_item = client.get_item("1").await?; + assert_eq!(return_item, item); + // Delete this item. + let rows_affected = client.delete_item("1").await?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_item("1").await; + assert!(response.is_err()); + // Insert a items. + let items = vec![ + item, + Item::new( + "12", + vec!["d", "e"], + vec!["a", "b", "c"], + "2023-11-20T13:55:27Z", + ), + ]; + let rows_affected = client.insert_items(&items).await?; + assert_eq!(rows_affected.row_affected, 2); + // Get this items. + let return_items = client.list_items(&CursorQuery::new()).await?; + assert!(!return_items.is_empty()); + // Delete this items. + let rows_affected = client.delete_item("1").await?; + assert_eq!(rows_affected.row_affected, 1); + let rows_affected = client.delete_item("12").await?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_feedback() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let all_feedback = vec![ + Feedback::new("read", "10", "3", "2022-11-20T13:55:27Z"), + Feedback::new("read", "10", "4", "2022-11-20T13:55:27Z"), + ]; + let feedback = vec![ + Feedback::new("read", "10", "3", "2022-11-20T13:55:27Z"), + Feedback::new("read", "10", "4", "2022-11-20T13:55:27Z"), + Feedback::new("star", "10", "3", "2022-11-20T13:55:27Z"), + Feedback::new("read", "10", "5", "2022-11-20T13:55:27Z"), + Feedback::new("star", "10", "5", "2022-11-20T13:55:27Z"), + ]; + // Insert feedback. + let rows_affected = client.insert_feedback(&feedback).await?; + assert_eq!(rows_affected.row_affected, 5); + // Overwrite feedback. + let rows_affected = client.overwrite_feedback(&feedback).await?; + assert_eq!(rows_affected.row_affected, 5); + // Delete feedback. + let rows_affected = client.delete_feedback("star", "10", "3").await?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_feedback("star", "10", "3").await; + assert!(response.is_err()); + // Delete feedback from user by item. + let rows_affected = client.delete_feedback_from_user_by_item("10", "5").await?; + assert_eq!(rows_affected.row_affected, 2); + // List feedback. + let return_feedback = client.list_feedback(&CursorQuery::new()).await?; + assert!(all_feedback + .iter() + .all(|feedback| return_feedback.contains(feedback))); + // List feedback by type. + let return_feedback = client + .list_feedback_by_type("read", &CursorQuery::new()) + .await?; + assert!(all_feedback + .iter() + .all(|feedback| return_feedback.contains(feedback))); + // Get feedback. + let return_feedback = client.get_feedback("read", "10", "3").await?; + assert_eq!(return_feedback, feedback[0]); + // List feedback by item. + let return_feedback = client.list_feedback_by_item("3").await?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback by item and type. + let return_feedback = client.list_feedback_by_item_and_type("3", "read").await?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback from user. + let return_feedback = client.list_feedback_from_user("10").await?; + assert_eq!(return_feedback, feedback[..2]); + // List feedback from user by item. + let return_feedback = client.list_feedback_from_user_by_item("10", "3").await?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback from user by type. + let return_feedback = client.list_feedback_from_user_by_type("10", "read").await?; + assert_eq!(return_feedback, feedback[..2]); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_neighbors() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("user_neighbors/10")?; + connection.del("item_neighbors/10")?; + connection.del("item_neighbors/10/test")?; + connection.zadd_multiple("user_neighbors/10", &[(1, 10), (2, 20), (3, 30)])?; + connection.zadd_multiple("item_neighbors/10", &[(1, 10), (2, 20), (3, 30)])?; + connection.zadd_multiple("item_neighbors/10/test", &[(1, 10), (2, 20), (3, 30)])?; + let scores = vec![ + Score { + id: "30".into(), + score: 3.0, + }, + Score { + id: "20".into(), + score: 2.0, + }, + Score { + id: "10".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get item neighbors. + let returned_scores = client.get_item_neighbors("10", &OffsetQuery::new()).await?; + assert_eq!(returned_scores, scores); + // Get item neighbors by category. + let returned_scores = client + .get_item_neighbors_by_category("10", "test", &OffsetQuery::new()) + .await?; + assert_eq!(returned_scores, scores); + // Get user neighbors. + let returned_scores = client.get_user_neighbors("10", &OffsetQuery::new()).await?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_latest() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("latest_items")?; + connection.del("latest_items/test")?; + connection.zadd_multiple("latest_items", &[(1, 10), (2, 20), (3, 30)])?; + connection.zadd_multiple("latest_items/test", &[(1, 10), (2, 20), (3, 30)])?; + let scores = vec![ + Score { + id: "30".into(), + score: 3.0, + }, + Score { + id: "20".into(), + score: 2.0, + }, + Score { + id: "10".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get latest. + let returned_scores = client.get_latest(&UserIdQuery::new()).await?; + assert_eq!(returned_scores, scores); + // Get latest by category. + let returned_scores = client + .get_latest_by_category("test", &UserIdQuery::new()) + .await?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_popular() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("popular_items")?; + connection.del("popular_items/test")?; + connection.zadd_multiple("popular_items", &[(1, 10), (2, 20), (3, 30)])?; + connection.zadd_multiple("popular_items/test", &[(1, 10), (2, 20), (3, 30)])?; + let scores = vec![ + Score { + id: "30".into(), + score: 3.0, + }, + Score { + id: "20".into(), + score: 2.0, + }, + Score { + id: "10".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get popular. + let returned_scores = client.get_popular(&UserIdQuery::new()).await?; + assert_eq!(returned_scores, scores); + // Get popular by category. + let returned_scores = client + .get_popular_by_category("test", &UserIdQuery::new()) + .await?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_recommend() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("offline_recommend/10")?; + connection.del("offline_recommend/10/test")?; + connection.zadd_multiple("offline_recommend/10", &[(1, 10), (2, 20), (3, 30)])?; + connection.zadd_multiple("offline_recommend/10/test", &[(1, 10), (2, 20), (3, 30)])?; + let items = vec!["30".to_string(), "20".to_string(), "10".to_string()]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get recommendation. + let returned_items = client.get_recommend("10", &WriteBackQuery::new()).await?; + assert_eq!(returned_items, items); + // Get recommendation by category. + let returned_items = client + .get_recommend_by_category("10", "test", &WriteBackQuery::new()) + .await?; + assert_eq!(returned_items, items); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_recommend_session() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let items = vec![Item::new( + "9", + vec!["a", "b", "c"], + vec!["d", "e"], + "2022-11-20T13:55:27Z", + )]; + let feedbacks = vec![Feedback::new("read", "10", "9", "2022-11-21T13:55:27Z")]; + // Insert an item. + let rows_affected = client.insert_items(&items).await?; + assert_eq!(rows_affected.row_affected, 1); + // Get recommendation. + let returned_scores = client + .get_recommend_session(&feedbacks, &OffsetQuery::new()) + .await?; + assert!(returned_scores.is_empty()); + // Get recommendation by category. + let returned_scores = client + .get_recommend_session_by_category(&feedbacks, "test", &OffsetQuery::new()) + .await?; + assert!(returned_scores.is_empty()); + // Delete a feedback. + let rows_affected = client.delete_feedback("read", "10", "9").await?; + assert_eq!(rows_affected.row_affected, 0); + // Delete an item. + let rows_affected = client.delete_item("9").await?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_health() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let health = Health { + cache_store_connected: true, + cache_store_error: None, + data_store_connected: true, + data_store_error: None, + ready: true, + }; + let return_health = client.is_live().await?; + assert_eq!(return_health, health); + let return_health = client.is_ready().await?; + assert_eq!(return_health, health); + Ok(()) + } +} diff --git a/src/blocking.rs b/src/blocking.rs new file mode 100644 index 0000000..2ffca70 --- /dev/null +++ b/src/blocking.rs @@ -0,0 +1,867 @@ +use reqwest::blocking::Client; +use serde::{Deserialize, Serialize}; + +use crate::{ + query::{CursorQuery, OffsetQuery, UserIdQuery, WriteBackQuery}, + Error, Feedback, Feedbacks, Health, Item, Items, Method, Result, RowAffected, Score, + StatusCode, User, Users, +}; + +#[derive(Debug, Clone)] +pub struct Gorse { + entry_point: String, + api_key: String, + client: Client, +} + +impl Gorse { + pub fn new(entry_point: impl Into, api_key: impl Into) -> Self { + Self { + entry_point: entry_point.into(), + api_key: api_key.into(), + client: Client::new(), + } + } + + pub fn insert_user(&self, user: &User) -> Result { + return self.request(Method::POST, format!("{}api/user", self.entry_point), user); + } + + pub fn get_user(&self, user_id: &str) -> Result { + return self.request::<(), User>( + Method::GET, + format!("{}api/user/{}", self.entry_point, user_id), + &(), + ); + } + + pub fn delete_user(&self, user_id: &str) -> Result { + return self.request::<(), RowAffected>( + Method::DELETE, + format!("{}api/user/{}", self.entry_point, user_id), + &(), + ); + } + + pub fn update_user(&self, user: &User) -> Result { + return self.request( + Method::PATCH, + format!("{}api/user/{}", self.entry_point, user.user_id), + user, + ); + } + + pub fn list_users(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Users>( + Method::GET, + format!( + "{}api/users?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .map(|users| users.users); + } + + pub fn insert_users(&self, users: &Vec) -> Result { + return self.request( + Method::POST, + format!("{}api/users", self.entry_point), + users, + ); + } + + pub fn insert_item(&self, item: &Item) -> Result { + return self.request(Method::POST, format!("{}api/item", self.entry_point), item); + } + + pub fn get_item(&self, item_id: &str) -> Result { + return self.request::<(), Item>( + Method::GET, + format!("{}api/item/{}", self.entry_point, item_id), + &(), + ); + } + + pub fn delete_item(&self, item_id: &str) -> Result { + return self.request::<(), RowAffected>( + Method::DELETE, + format!("{}api/item/{}", self.entry_point, item_id), + &(), + ); + } + + pub fn update_item(&self, item: &Item) -> Result { + return self.request( + Method::PATCH, + format!("{}api/item/{}", self.entry_point, item.item_id), + item, + ); + } + + pub fn add_item_to_category(&self, item_id: &str, category: &str) -> Result { + return self.request( + Method::PUT, + format!( + "{}api/item/{}/category/{}", + self.entry_point, item_id, category + ), + &(), + ); + } + + pub fn delete_item_to_category(&self, item_id: &str, category: &str) -> Result { + return self.request( + Method::DELETE, + format!( + "{}api/item/{}/category/{}", + self.entry_point, item_id, category + ), + &(), + ); + } + + pub fn list_items(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Items>( + Method::GET, + format!( + "{}api/items?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .map(|items| items.items); + } + + pub fn insert_items(&self, items: &Vec) -> Result { + return self.request( + Method::POST, + format!("{}api/items", self.entry_point), + items, + ); + } + + pub fn list_feedback(&self, query: &CursorQuery) -> Result> { + return self + .request::<(), Feedbacks>( + Method::GET, + format!( + "{}api/feedback?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ) + .map(|feedbacks| feedbacks.feedbacks); + } + + pub fn overwrite_feedback(&self, feedback: &Vec) -> Result { + return self.request( + Method::PUT, + format!("{}api/feedback", self.entry_point), + feedback, + ); + } + + pub fn insert_feedback(&self, feedback: &Vec) -> Result { + return self.request( + Method::POST, + format!("{}api/feedback", self.entry_point), + feedback, + ); + } + + pub fn list_feedback_by_type( + &self, + feedback_type: &str, + query: &CursorQuery, + ) -> Result> { + return self + .request::<(), Feedbacks>( + Method::GET, + format!( + "{}api/feedback/{}?{}", + self.entry_point, + feedback_type, + serde_url_params::to_string(query).unwrap() + ), + &(), + ) + .map(|feedbacks| feedbacks.feedbacks); + } + + pub fn get_feedback( + &self, + feedback_type: &str, + user_id: &str, + item_id: &str, + ) -> Result { + return self.request::<(), Feedback>( + Method::GET, + format!( + "{}api/feedback/{}/{}/{}", + self.entry_point, feedback_type, user_id, item_id, + ), + &(), + ); + } + + pub fn delete_feedback( + &self, + feedback_type: &str, + user_id: &str, + item_id: &str, + ) -> Result { + return self.request::<(), RowAffected>( + Method::DELETE, + format!( + "{}api/feedback/{}/{}/{}", + self.entry_point, feedback_type, user_id, item_id, + ), + &(), + ); + } + + pub fn list_feedback_from_user_by_item( + &self, + user_id: &str, + item_id: &str, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!("{}api/feedback/{}/{}", self.entry_point, user_id, item_id), + &(), + ); + } + + pub fn delete_feedback_from_user_by_item( + &self, + user_id: &str, + item_id: &str, + ) -> Result { + return self.request::<(), RowAffected>( + Method::DELETE, + format!("{}api/feedback/{}/{}", self.entry_point, user_id, item_id), + &(), + ); + } + + pub fn list_feedback_by_item(&self, item_id: &str) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!("{}api/item/{}/feedback", self.entry_point, item_id), + &(), + ); + } + + pub fn list_feedback_by_item_and_type( + &self, + item_id: &str, + feedback_type: &str, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/feedback/{}", + self.entry_point, item_id, feedback_type + ), + &(), + ); + } + + pub fn list_feedback_from_user(&self, user_id: &str) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!("{}api/user/{}/feedback", self.entry_point, user_id), + &(), + ); + } + + pub fn list_feedback_from_user_by_type( + &self, + user_id: &str, + feedback_type: &str, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/user/{}/feedback/{}", + self.entry_point, user_id, feedback_type + ), + &(), + ); + } + + pub fn get_item_neighbors(&self, item_id: &str, query: &OffsetQuery) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/neighbors?{}", + self.entry_point, + item_id, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_item_neighbors_by_category( + &self, + item_id: &str, + category: &str, + query: &OffsetQuery, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/item/{}/neighbors/{}?{}", + self.entry_point, + item_id, + category, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_latest(&self, query: &UserIdQuery) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/latest?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_latest_by_category( + &self, + category: &str, + query: &UserIdQuery, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/latest/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_popular(&self, query: &UserIdQuery) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/popular?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_popular_by_category( + &self, + category: &str, + query: &UserIdQuery, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/popular/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_recommend(&self, user_id: &str, query: &WriteBackQuery) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/recommend/{}?{}", + self.entry_point, + user_id, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_recommend_by_category( + &self, + user_id: &str, + category: &str, + query: &WriteBackQuery, + ) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/recommend/{}/{}?{}", + self.entry_point, + user_id, + category, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + pub fn get_recommend_session( + &self, + feedbacks: &Vec, + query: &OffsetQuery, + ) -> Result> { + return self + .request::, Option>>( + Method::POST, + format!( + "{}api/session/recommend?{}", + self.entry_point, + serde_url_params::to_string(query)? + ), + feedbacks, + ) + .map(|scores| scores.unwrap_or_default()); + } + + pub fn get_recommend_session_by_category( + &self, + feedbacks: &Vec, + category: &str, + query: &OffsetQuery, + ) -> Result> { + return self + .request::, Option>>( + Method::POST, + format!( + "{}api/session/recommend/{}?{}", + self.entry_point, + category, + serde_url_params::to_string(query)? + ), + feedbacks, + ) + .map(|scores| scores.unwrap_or_default()); + } + + pub fn get_user_neighbors(&self, user_id: &str, query: &OffsetQuery) -> Result> { + return self.request::<(), Vec>( + Method::GET, + format!( + "{}api/user/{}/neighbors?{}", + self.entry_point, + user_id, + serde_url_params::to_string(query)? + ), + &(), + ); + } + + fn request Deserialize<'a>>( + &self, + method: Method, + url: String, + body: &BodyType, + ) -> Result { + let response = self + .client + .request(method, url) + .header("X-API-Key", self.api_key.as_str()) + .header("Content-Type", "application/json") + .json(body) + .send()?; + return if response.status() == StatusCode::OK { + let r: RetType = serde_json::from_str(response.text()?.as_str())?; + Ok(r) + } else { + Err(Box::new(Error { + status_code: response.status(), + message: response.text()?, + })) + }; + } + + pub fn is_live(&self) -> Result { + return self.request( + Method::GET, + format!("{}api/health/live", self.entry_point), + &(), + ); + } + + pub fn is_ready(&self) -> Result { + return self.request( + Method::GET, + format!("{}api/health/ready", self.entry_point), + &(), + ); + } +} + +#[cfg(test)] +mod tests { + use redis::Commands; + use serial_test::serial; + + use super::*; + + const ENTRY_POINT: &str = "http://127.0.0.1:8088/"; + const API_KEY: &str = "zhenghaoz"; + + #[test] + #[serial] + fn test_users() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let mut user = User::new("100", vec!["a", "b", "c"]); + // Insert a user. + let rows_affected = client.insert_user(&user)?; + assert_eq!(rows_affected.row_affected, 1); + // Get this user. + let return_user = client.get_user("100")?; + assert_eq!(return_user, user); + // Update this user. + user.labels = vec!["e".into(), "f".into(), "g".into()]; + let rows_affected = client.update_user(&user)?; + assert_eq!(rows_affected.row_affected, 1); + // Get this user. + let return_user = client.get_user("100")?; + assert_eq!(return_user, user); + // Delete this user. + let rows_affected = client.delete_user("100")?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_user("100"); + assert!(response.is_err()); + // Insert a users. + let users = vec![user, User::new("102", vec!["a", "b", "c"])]; + let rows_affected = client.insert_users(&users)?; + assert_eq!(rows_affected.row_affected, 2); + // Get this users. + let return_users = client.list_users(&CursorQuery::new())?; + assert!(!return_users.is_empty()); + // Delete this users. + let rows_affected = client.delete_user("100")?; + assert_eq!(rows_affected.row_affected, 1); + let rows_affected = client.delete_user("102")?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[test] + #[serial] + fn test_items() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let category = "test".to_string(); + let item = Item::new( + "100", + vec!["a", "b", "c"], + vec!["d", "e"], + "2022-11-20T13:55:27Z", + ) + .comment("comment"); + let item_with_category = Item::new( + "100", + vec!["a", "b", "c"], + vec!["d", "e", &category], + "2022-11-20T13:55:27Z", + ) + .comment("comment"); + // Insert an item. + let rows_affected = client.insert_item(&item)?; + assert_eq!(rows_affected.row_affected, 1); + // Add category to item. + let rows_affected = client.add_item_to_category(&item.item_id, &category)?; + assert_eq!(rows_affected.row_affected, 1); + // Get this item. + let return_item = client.get_item("100")?; + assert_eq!(return_item, item_with_category); + // Delete category to item. + let rows_affected = client.delete_item_to_category(&item.item_id, &category)?; + assert_eq!(rows_affected.row_affected, 1); + // Get this item. + let return_item = client.get_item("100")?; + assert_eq!(return_item, item); + // Delete this item. + let rows_affected = client.delete_item("100")?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_item("100"); + assert!(response.is_err()); + // Insert a items. + let items = vec![ + item, + Item::new( + "102", + vec!["d", "e"], + vec!["a", "b", "c"], + "2023-11-20T13:55:27Z", + ), + ]; + let rows_affected = client.insert_items(&items)?; + assert_eq!(rows_affected.row_affected, 2); + // Get this items. + let return_items = client.list_items(&CursorQuery::new())?; + assert!(!return_items.is_empty()); + // Delete this items. + let rows_affected = client.delete_item("100")?; + assert_eq!(rows_affected.row_affected, 1); + let rows_affected = client.delete_item("102")?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[test] + #[serial] + fn test_feedback() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let all_feedback = vec![ + Feedback::new("read", "1000", "300", "2022-11-20T13:55:27Z"), + Feedback::new("read", "1000", "400", "2022-11-20T13:55:27Z"), + ]; + let feedback = vec![ + Feedback::new("read", "1000", "300", "2022-11-20T13:55:27Z"), + Feedback::new("read", "1000", "400", "2022-11-20T13:55:27Z"), + Feedback::new("star", "1000", "300", "2022-11-20T13:55:27Z"), + Feedback::new("read", "1000", "500", "2022-11-20T13:55:27Z"), + Feedback::new("star", "1000", "500", "2022-11-20T13:55:27Z"), + ]; + // Insert feedback. + let rows_affected = client.insert_feedback(&feedback)?; + assert_eq!(rows_affected.row_affected, 5); + // Overwrite feedback. + let rows_affected = client.overwrite_feedback(&feedback)?; + assert_eq!(rows_affected.row_affected, 5); + // Delete feedback. + let rows_affected = client.delete_feedback("star", "1000", "300")?; + assert_eq!(rows_affected.row_affected, 1); + let response = client.get_feedback("star", "1000", "300"); + assert!(response.is_err()); + // Delete feedback from user by item. + let rows_affected = client.delete_feedback_from_user_by_item("1000", "500")?; + assert_eq!(rows_affected.row_affected, 2); + // List feedback. + let return_feedback = client.list_feedback(&CursorQuery::new())?; + assert!(all_feedback + .iter() + .all(|feedback| return_feedback.contains(feedback))); + // List feedback by type. + let return_feedback = client.list_feedback_by_type("read", &CursorQuery::new())?; + assert!(all_feedback + .iter() + .all(|feedback| return_feedback.contains(feedback))); + // Get feedback. + let return_feedback = client.get_feedback("read", "1000", "300")?; + assert_eq!(return_feedback, feedback[0]); + // List feedback by item. + let return_feedback = client.list_feedback_by_item("300")?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback by item and type. + let return_feedback = client.list_feedback_by_item_and_type("300", "read")?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback from user. + let return_feedback = client.list_feedback_from_user("1000")?; + assert_eq!(return_feedback, feedback[..2]); + // List feedback from user by item. + let return_feedback = client.list_feedback_from_user_by_item("1000", "300")?; + assert_eq!(return_feedback, feedback[..1]); + // List feedback from user by type. + let return_feedback = client.list_feedback_from_user_by_type("1000", "read")?; + assert_eq!(return_feedback, feedback[..2]); + Ok(()) + } + + #[test] + #[serial] + fn test_neighbors() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("user_neighbors/1000")?; + connection.del("item_neighbors/1000")?; + connection.del("item_neighbors/1000/test")?; + connection.zadd_multiple("user_neighbors/1000", &[(1, 1000), (2, 2000), (3, 3000)])?; + connection.zadd_multiple("item_neighbors/1000", &[(1, 1000), (2, 2000), (3, 3000)])?; + connection.zadd_multiple( + "item_neighbors/1000/test", + &[(1, 1000), (2, 2000), (3, 3000)], + )?; + let scores = vec![ + Score { + id: "3000".into(), + score: 3.0, + }, + Score { + id: "2000".into(), + score: 2.0, + }, + Score { + id: "1000".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get item neighbors. + let returned_scores = client.get_item_neighbors("1000", &OffsetQuery::new())?; + assert_eq!(returned_scores, scores); + // Get item neighbors by category. + let returned_scores = + client.get_item_neighbors_by_category("1000", "test", &OffsetQuery::new())?; + assert_eq!(returned_scores, scores); + // Get user neighbors. + let returned_scores = client.get_user_neighbors("1000", &OffsetQuery::new())?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[test] + #[serial] + fn test_latest() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("latest_items")?; + connection.del("latest_items/test")?; + connection.zadd_multiple("latest_items", &[(1, 1000), (2, 2000), (3, 3000)])?; + connection.zadd_multiple("latest_items/test", &[(1, 1000), (2, 2000), (3, 3000)])?; + let scores = vec![ + Score { + id: "3000".into(), + score: 3.0, + }, + Score { + id: "2000".into(), + score: 2.0, + }, + Score { + id: "1000".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get latest. + let returned_scores = client.get_latest(&UserIdQuery::new())?; + assert_eq!(returned_scores, scores); + // Get latest by category. + let returned_scores = client.get_latest_by_category("test", &UserIdQuery::new())?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[test] + #[serial] + fn test_popular() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("popular_items")?; + connection.del("popular_items/test")?; + connection.zadd_multiple("popular_items", &[(1, 1000), (2, 2000), (3, 3000)])?; + connection.zadd_multiple("popular_items/test", &[(1, 1000), (2, 2000), (3, 3000)])?; + let scores = vec![ + Score { + id: "3000".into(), + score: 3.0, + }, + Score { + id: "2000".into(), + score: 2.0, + }, + Score { + id: "1000".into(), + score: 1.0, + }, + ]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get popular. + let returned_scores = client.get_popular(&UserIdQuery::new())?; + assert_eq!(returned_scores, scores); + // Get popular by category. + let returned_scores = client.get_popular_by_category("test", &UserIdQuery::new())?; + assert_eq!(returned_scores, scores); + Ok(()) + } + + #[test] + #[serial] + fn test_recommend() -> Result<()> { + let redis = redis::Client::open("redis://127.0.0.1/")?; + let mut connection = redis.get_connection()?; + connection.del("offline_recommend/1000")?; + connection.del("offline_recommend/1000/test")?; + connection.zadd_multiple("offline_recommend/1000", &[(1, 1000), (2, 2000), (3, 3000)])?; + connection.zadd_multiple( + "offline_recommend/1000/test", + &[(1, 1000), (2, 2000), (3, 3000)], + )?; + let items = vec!["3000".to_string(), "2000".to_string(), "1000".to_string()]; + let client = Gorse::new(ENTRY_POINT, API_KEY); + // Get recommendation. + let returned_items = client.get_recommend("1000", &WriteBackQuery::new())?; + assert_eq!(returned_items, items); + // Get recommendation by category. + let returned_items = + client.get_recommend_by_category("1000", "test", &WriteBackQuery::new())?; + assert_eq!(returned_items, items); + Ok(()) + } + + #[test] + #[serial] + fn test_recommend_session() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let items = vec![Item::new( + "1009", + vec!["a", "b", "c"], + vec!["d", "e"], + "2022-11-20T13:55:27Z", + )]; + let feedbacks = vec![Feedback::new( + "read", + "1000", + "1009", + "2022-11-21T13:55:27Z", + )]; + // Insert an item. + let rows_affected = client.insert_items(&items)?; + assert_eq!(rows_affected.row_affected, 1); + // Get recommendation. + let returned_scores = client.get_recommend_session(&feedbacks, &OffsetQuery::new())?; + assert!(returned_scores.is_empty()); + // Get recommendation by category. + let returned_scores = + client.get_recommend_session_by_category(&feedbacks, "test", &OffsetQuery::new())?; + assert!(returned_scores.is_empty()); + // Delete a feedback. + let rows_affected = client.delete_feedback("read", "1000", "1009")?; + assert_eq!(rows_affected.row_affected, 0); + // Delete an item. + let rows_affected = client.delete_item("1009")?; + assert_eq!(rows_affected.row_affected, 1); + Ok(()) + } + + #[test] + #[serial] + fn test_health() -> Result<()> { + let client = Gorse::new(ENTRY_POINT, API_KEY); + let health = Health { + cache_store_connected: true, + cache_store_error: None, + data_store_connected: true, + data_store_error: None, + ready: true, + }; + let return_health = client.is_live()?; + assert_eq!(return_health, health); + let return_health = client.is_ready()?; + assert_eq!(return_health, health); + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index cc6d25c..a1d948c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,24 @@ -use std::{error, fmt}; +mod r#async; +pub mod blocking; +pub mod query; + use std::cmp::PartialEq; use std::fmt::{Display, Formatter}; +use std::{error, fmt}; +pub use r#async::*; use reqwest::{Method, StatusCode}; -use reqwest::Client; use serde::{Deserialize, Serialize}; type Result = std::result::Result>; -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub(crate) struct Users { + #[serde(rename = "Users")] + users: Vec, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct User { #[serde(rename = "UserId")] pub user_id: String, @@ -16,7 +26,22 @@ pub struct User { pub labels: Vec, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +impl User { + pub fn new(user_id: impl Into, labels: Vec>) -> Self { + User { + user_id: user_id.into(), + labels: labels.into_iter().map(|label| label.into()).collect(), + } + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub(crate) struct Items { + #[serde(rename = "Items")] + items: Vec, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct Item { #[serde(rename = "ItemId")] pub item_id: String, @@ -32,7 +57,44 @@ pub struct Item { pub comment: String, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +impl Item { + pub fn new( + item_id: impl Into, + labels: Vec>, + categories: Vec>, + timestamp: impl Into, + ) -> Self { + Item { + item_id: item_id.into(), + is_hidden: false, + labels: labels.into_iter().map(|label| label.into()).collect(), + categories: categories + .into_iter() + .map(|category| category.into()) + .collect(), + timestamp: timestamp.into(), + comment: String::new(), + } + } + + pub fn is_hidden(mut self, is_hidden: bool) -> Self { + self.is_hidden = is_hidden; + self + } + + pub fn comment(mut self, comment: impl Into) -> Self { + self.comment = comment.into(); + self + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub(crate) struct Feedbacks { + #[serde(rename = "Feedback")] + feedbacks: Vec, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct Feedback { #[serde(rename = "FeedbackType")] pub feedback_type: String, @@ -45,23 +107,42 @@ pub struct Feedback { } impl Feedback { - pub fn new(feedback_type: impl Into, user_id: impl Into, item_id: impl Into, timestamp: impl Into) -> Self { - return Feedback { + pub fn new( + feedback_type: impl Into, + user_id: impl Into, + item_id: impl Into, + timestamp: impl Into, + ) -> Self { + Feedback { feedback_type: feedback_type.into(), user_id: user_id.into(), item_id: item_id.into(), timestamp: timestamp.into(), - }; + } } } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub struct Health { + #[serde(rename = "CacheStoreConnected")] + pub cache_store_connected: bool, + #[serde(rename = "CacheStoreError")] + pub cache_store_error: Option, + #[serde(rename = "DataStoreConnected")] + pub data_store_connected: bool, + #[serde(rename = "DataStoreError")] + pub data_store_error: Option, + #[serde(rename = "Ready")] + pub ready: bool, +} + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct RowAffected { #[serde(rename = "RowAffected")] pub row_affected: i32, } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub struct Score { #[serde(rename = "Id")] pub id: String, @@ -86,340 +167,3 @@ impl error::Error for Error { &self.message } } - -#[derive(Debug, Clone)] -pub struct Gorse { - entry_point: String, - api_key: String, - client: Client, -} - -impl Gorse { - pub fn new(entry_point: impl Into, api_key: impl Into) -> Self { - Self { - entry_point: entry_point.into(), - api_key: api_key.into(), - client: Client::new(), - } - } - - pub async fn insert_user(&self, user: &User) -> Result { - return self.request(Method::POST, format!("{}api/user", self.entry_point), user).await; - } - - pub async fn get_user(&self, user_id: &str) -> Result { - return self.request::<(), User>(Method::GET, format!("{}api/user/{}", self.entry_point, user_id), &()).await; - } - - pub async fn delete_user(&self, user_id: &str) -> Result { - return self.request::<(), RowAffected>(Method::DELETE, format!("{}api/user/{}", self.entry_point, user_id), &()).await; - } - - pub async fn insert_item(&self, item: &Item) -> Result { - return self.request(Method::POST, format!("{}api/item", self.entry_point), item).await; - } - - pub async fn get_item(&self, item_id: &str) -> Result { - return self.request::<(), Item>(Method::GET, format!("{}api/item/{}", self.entry_point, item_id), &()).await; - } - - pub async fn delete_item(&self, item_id: &str) -> Result { - return self.request::<(), RowAffected>(Method::DELETE, format!("{}api/item/{}", self.entry_point, item_id), &()).await; - } - - pub async fn insert_feedback(&self, feedback: &Vec) -> Result { - return self.request(Method::POST, format!("{}api/feedback", self.entry_point), feedback).await; - } - - pub async fn list_feedback(&self, user_id: &str, feedback_type: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/user/{}/feedback/{}", self.entry_point, user_id, feedback_type), &()).await; - } - - pub async fn get_item_neighbors(&self, item_id: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/item/{}/neighbors", self.entry_point, item_id), &()).await; - } - - pub async fn get_recommend(&self, user_id: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/recommend/{}", self.entry_point, user_id), &()).await; - } - - async fn request Deserialize<'a>>(&self, method: Method, url: String, body: &BodyType) -> Result { - let response = self.client.request(method, url) - .header("X-API-Key", self.api_key.as_str()) - .header("Content-Type", "application/json") - .json(body) - .send().await?; - return if response.status() == StatusCode::OK { - let r: RetType = serde_json::from_str(response.text().await?.as_str())?; - Ok(r) - } else { - Err(Box::new(Error { status_code: response.status(), message: response.text().await? })) - }; - } -} - -#[cfg(test)] -mod tests { - use redis::Commands; - - use super::{*}; - - const ENTRY_POINT: &str = "http://127.0.0.1:8088/"; - const API_KEY: &str = "zhenghaoz"; - - #[tokio::test] - async fn test_users() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let user = User { user_id: "1".into(), labels: vec!["a".into(), "b".into(), "c".into()] }; - // Insert a user. - let rows_affected = client.insert_user(&user).await?; - assert_eq!(rows_affected.row_affected, 1); - // Get this user. - let return_user = client.get_user("1").await?; - assert_eq!(return_user, user); - // Delete this user. - let rows_affected = client.delete_user("1").await?; - assert_eq!(rows_affected.row_affected, 1); - let response = client.get_user("1").await; - assert!(response.is_err()); - Ok(()) - } - - #[tokio::test] - async fn test_items() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let item = Item { - item_id: "1".into(), - is_hidden: true, - labels: vec!["a".into(), "b".into(), "c".into()], - categories: vec!["d".into(), "e".into()], - timestamp: "2022-11-20T13:55:27Z".into(), - comment: "comment".into(), - }; - // Insert an item. - let rows_affected = client.insert_item(&item).await?; - assert_eq!(rows_affected.row_affected, 1); - // Get this item. - let return_item = client.get_item("1").await?; - assert_eq!(return_item, item); - // Delete this item. - let rows_affected = client.delete_item("1").await?; - assert_eq!(rows_affected.row_affected, 1); - let response = client.get_item("1").await; - assert!(response.is_err()); - Ok(()) - } - - #[tokio::test] - async fn test_feedback() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let feedback = vec![ - Feedback::new("read", "10", "3", "2022-11-20T13:55:27Z"), - Feedback::new("read", "10", "4", "2022-11-20T13:55:27Z"), - ]; - let rows_affected = client.insert_feedback(&feedback).await?; - assert_eq!(rows_affected.row_affected, 2); - // Insert feedback. - let return_feedback = client.list_feedback("10", "read").await?; - assert_eq!(return_feedback, feedback); - Ok(()) - } - - #[tokio::test] - async fn test_neighbors() -> Result<()> { - let redis = redis::Client::open("redis://127.0.0.1/")?; - let mut connection = redis.get_connection()?; - connection.zadd_multiple("item_neighbors/10", &[(1, 10), (2, 20), (3, 30)])?; - let client = Gorse::new(ENTRY_POINT, API_KEY); - let scores = client.get_item_neighbors("10").await?; - assert_eq!(scores, vec![ - Score { id: "30".into(), score: 3.0 }, - Score { id: "20".into(), score: 2.0 }, - Score { id: "10".into(), score: 1.0 }, - ]); - Ok(()) - } - - #[tokio::test] - async fn test_recommend() -> Result<()> { - let redis = redis::Client::open("redis://127.0.0.1/")?; - let mut connection = redis.get_connection()?; - connection.zadd_multiple("offline_recommend/10", &[(1, 10), (2, 20), (3, 30)])?; - let client = Gorse::new(ENTRY_POINT, API_KEY); - let items = client.get_recommend("10").await?; - assert_eq!(items, vec!["30".to_string(), "20".to_string(), "10".to_string()]); - Ok(()) - } -} - -pub mod blocking { - use reqwest::blocking::Client; - use serde::{Deserialize, Serialize}; - - use crate::{Error, Feedback, Item, Method, Result, RowAffected, Score, StatusCode, User}; - - #[derive(Debug, Clone)] - pub struct Gorse { - entry_point: String, - api_key: String, - client: Client, - } - - impl Gorse { - pub fn new(entry_point: impl Into, api_key: impl Into) -> Self { - Self { - entry_point: entry_point.into(), - api_key: api_key.into(), - client: Client::new(), - } - } - - pub fn insert_user(&self, user: &User) -> Result { - return self.request(Method::POST, format!("{}api/user", self.entry_point), user); - } - - pub fn get_user(&self, user_id: &str) -> Result { - return self.request::<(), User>(Method::GET, format!("{}api/user/{}", self.entry_point, user_id), &()); - } - - pub fn delete_user(&self, user_id: &str) -> Result { - return self.request::<(), RowAffected>(Method::DELETE, format!("{}api/user/{}", self.entry_point, user_id), &()); - } - - pub fn insert_item(&self, item: &Item) -> Result { - return self.request(Method::POST, format!("{}api/item", self.entry_point), item); - } - - pub fn get_item(&self, item_id: &str) -> Result { - return self.request::<(), Item>(Method::GET, format!("{}api/item/{}", self.entry_point, item_id), &()); - } - - pub fn delete_item(&self, item_id: &str) -> Result { - return self.request::<(), RowAffected>(Method::DELETE, format!("{}api/item/{}", self.entry_point, item_id), &()); - } - - pub fn insert_feedback(&self, feedback: &Vec) -> Result { - return self.request(Method::POST, format!("{}api/feedback", self.entry_point), feedback); - } - - pub fn list_feedback(&self, user_id: &str, feedback_type: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/user/{}/feedback/{}", self.entry_point, user_id, feedback_type), &()); - } - - pub fn get_item_neighbors(&self, item_id: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/item/{}/neighbors", self.entry_point, item_id), &()); - } - - pub fn get_recommend(&self, user_id: &str) -> Result> { - return self.request::<(), Vec>(Method::GET, format!("{}api/recommend/{}", self.entry_point, user_id), &()); - } - - fn request Deserialize<'a>>(&self, method: Method, url: String, body: &BodyType) -> Result { - let response = self.client.request(method, url) - .header("X-API-Key", self.api_key.as_str()) - .header("Content-Type", "application/json") - .json(body) - .send()?; - return if response.status() == StatusCode::OK { - let r: RetType = serde_json::from_str(response.text()?.as_str())?; - Ok(r) - } else { - Err(Box::new(Error { status_code: response.status(), message: response.text()? })) - }; - } - } - - #[cfg(test)] - mod tests { - use redis::Commands; - - use super::{*}; - - const ENTRY_POINT: &str = "http://127.0.0.1:8088/"; - const API_KEY: &str = "zhenghaoz"; - - #[test] - fn test_users() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let user = User { user_id: "100".into(), labels: vec!["a".into(), "b".into(), "c".into()] }; - // Insert a user. - let rows_affected = client.insert_user(&user)?; - assert_eq!(rows_affected.row_affected, 1); - // Get this user. - let return_user = client.get_user("100")?; - assert_eq!(return_user, user); - // Delete this user. - let rows_affected = client.delete_user("100")?; - assert_eq!(rows_affected.row_affected, 1); - let response = client.get_user("100"); - assert!(response.is_err()); - Ok(()) - } - - #[test] - fn test_items() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let item = Item { - item_id: "100".into(), - is_hidden: true, - labels: vec!["a".into(), "b".into(), "c".into()], - categories: vec!["d".into(), "e".into()], - timestamp: "2022-11-20T13:55:27Z".into(), - comment: "comment".into(), - }; - // Insert an item. - let rows_affected = client.insert_item(&item)?; - assert_eq!(rows_affected.row_affected, 1); - // Get this item. - let return_item = client.get_item("100")?; - assert_eq!(return_item, item); - // Delete this item. - let rows_affected = client.delete_item("100")?; - assert_eq!(rows_affected.row_affected, 1); - let response = client.get_item("100"); - assert!(response.is_err()); - Ok(()) - } - - #[test] - fn test_feedback() -> Result<()> { - let client = Gorse::new(ENTRY_POINT, API_KEY); - let feedback = vec![ - Feedback::new("read", "1000", "300", "2022-11-20T13:55:27Z"), - Feedback::new("read", "1000", "400", "2022-11-20T13:55:27Z"), - ]; - let rows_affected = client.insert_feedback(&feedback)?; - assert_eq!(rows_affected.row_affected, 2); - // Insert feedback. - let return_feedback = client.list_feedback("1000", "read")?; - assert_eq!(return_feedback, feedback); - Ok(()) - } - - #[test] - fn test_neighbors() -> Result<()> { - let redis = redis::Client::open("redis://127.0.0.1/")?; - let mut connection = redis.get_connection()?; - connection.zadd_multiple("item_neighbors/1000", &[(1, 1000), (2, 2000), (3, 3000)])?; - let client = Gorse::new(ENTRY_POINT, API_KEY); - let scores = client.get_item_neighbors("1000")?; - assert_eq!(scores, vec![ - Score { id: "3000".into(), score: 3.0 }, - Score { id: "2000".into(), score: 2.0 }, - Score { id: "1000".into(), score: 1.0 }, - ]); - Ok(()) - } - - #[test] - fn test_recommend() -> Result<()> { - let redis = redis::Client::open("redis://127.0.0.1/")?; - let mut connection = redis.get_connection()?; - connection.zadd_multiple("offline_recommend/1000", &[(1, 1000), (2, 2000), (3, 3000)])?; - let client = Gorse::new(ENTRY_POINT, API_KEY); - let items = client.get_recommend("1000")?; - assert_eq!(items, vec!["3000".to_string(), "2000".to_string(), "1000".to_string()]); - Ok(()) - } - } -} diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..cc5931c --- /dev/null +++ b/src/query.rs @@ -0,0 +1,129 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct CursorQuery { + #[serde(rename = "n")] + pub number: Option, + pub cursor: Option, +} + +impl CursorQuery { + pub fn new() -> Self { + CursorQuery { + number: None, + cursor: None, + } + } + + pub fn number(mut self, number: i32) -> Self { + self.number = Some(number); + self + } + + pub fn cursor(mut self, cursor: impl Into) -> Self { + self.cursor = Some(cursor.into()); + self + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct OffsetQuery { + #[serde(rename = "n")] + pub number: Option, + pub offset: Option, +} + +impl OffsetQuery { + pub fn new() -> Self { + OffsetQuery { + number: None, + offset: None, + } + } + + pub fn number(mut self, number: i32) -> Self { + self.number = Some(number); + self + } + + pub fn offset(mut self, offset: i32) -> Self { + self.offset = Some(offset); + self + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct UserIdQuery { + #[serde(rename = "user-id")] + pub user_id: Option, + #[serde(rename = "n")] + pub number: Option, + pub offset: Option, +} + +impl UserIdQuery { + pub fn new() -> Self { + UserIdQuery { + user_id: None, + number: None, + offset: None, + } + } + + pub fn user_id(mut self, user_id: impl Into) -> Self { + self.user_id = Some(user_id.into()); + self + } + + pub fn number(mut self, number: i32) -> Self { + self.number = Some(number); + self + } + + pub fn offset(mut self, offset: i32) -> Self { + self.offset = Some(offset); + self + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct WriteBackQuery { + #[serde(rename = "write-back-type")] + pub write_back_type: Option, + #[serde(rename = "write-back-delay")] + pub write_back_delay: Option, + #[serde(rename = "n")] + pub number: Option, + pub offset: Option, +} + +impl WriteBackQuery { + pub fn new() -> Self { + WriteBackQuery { + write_back_type: None, + write_back_delay: None, + number: None, + offset: None, + } + } + + pub fn write_back_type(mut self, write_back_type: impl Into) -> Self { + self.write_back_type = Some(write_back_type.into()); + self + } + + pub fn write_back_delay(mut self, write_back_delay: impl Into) -> Self { + self.write_back_delay = Some(write_back_delay.into()); + self + } + + pub fn number(mut self, number: i32) -> Self { + self.number = Some(number); + self + } + + pub fn offset(mut self, offset: i32) -> Self { + self.offset = Some(offset); + self + } +}