From b5e70df305029d8be726b821ef98923fb56712de Mon Sep 17 00:00:00 2001 From: Omar Shehab <531omarhero@gmail.com> Date: Sun, 1 Sep 2024 20:13:55 +0300 Subject: [PATCH] Make all structs serializable --- Cargo.toml | 2 +- src/adapter/adapter_kind.rs | 3 ++- src/adapter/adapter_types.rs | 5 +++-- src/adapter/inter_stream.rs | 5 ++++- src/chat/chat_options.rs | 5 ++++- src/chat/chat_req.rs | 11 ++++++----- src/chat/chat_res.rs | 6 ++++-- src/chat/chat_stream.rs | 9 +++++---- src/chat/message_content.rs | 4 +++- src/chat/printer.rs | 3 ++- src/chat/tool.rs | 2 ++ src/common/model_iden.rs | 4 +++- src/common/model_name.rs | 4 +++- src/webc/web_client.rs | 3 ++- 14 files changed, 44 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 33f3c40..81552d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ tokio = { version = "1", features = ["full"] } futures = "0.3" tokio-stream = "0.1" # -- Json -serde = { version = "1", features = ["derive"] } +serde = { version = "1", features = ["derive", "rc"] } # Opted to rc for Arc serialization serde_json = "1" # -- Web reqwest = {version = "0.12", features = ["json"]} diff --git a/src/adapter/adapter_kind.rs b/src/adapter/adapter_kind.rs index 8c98e2f..fbe8b7d 100644 --- a/src/adapter/adapter_kind.rs +++ b/src/adapter/adapter_kind.rs @@ -1,8 +1,9 @@ use super::groq::MODELS as GROQ_MODELS; use crate::Result; use derive_more::Display; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Copy, Display, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Display, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum AdapterKind { OpenAI, Ollama, diff --git a/src/adapter/adapter_types.rs b/src/adapter/adapter_types.rs index ff4f3a0..c23fd1b 100644 --- a/src/adapter/adapter_types.rs +++ b/src/adapter/adapter_types.rs @@ -4,6 +4,7 @@ use crate::webc::WebResponse; use crate::Result; use crate::{ClientConfig, ModelIden}; use reqwest::RequestBuilder; +use serde::{Deserialize, Serialize}; use serde_json::Value; pub trait Adapter { @@ -38,7 +39,7 @@ pub trait Adapter { // region: --- ServiceType -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum ServiceType { Chat, ChatStream, @@ -49,7 +50,7 @@ pub enum ServiceType { // region: --- WebRequestData // NOTE: This cannot really move to `webc` because it has to be public with the adapter and `webc` is private for now. - +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct WebRequestData { pub url: String, pub headers: Vec<(String, String)>, diff --git a/src/adapter/inter_stream.rs b/src/adapter/inter_stream.rs index cc7ed29..073fbb8 100644 --- a/src/adapter/inter_stream.rs +++ b/src/adapter/inter_stream.rs @@ -5,9 +5,11 @@ //! //! NOTE: This might be removed at some point as it might not be needed, and going directly to the genai stream. +use serde::{Deserialize, Serialize}; + use crate::chat::MetaUsage; -#[derive(Debug, Default)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct InterStreamEnd { // When `ChatOptions..capture_usage == true` pub captured_usage: Option, @@ -17,6 +19,7 @@ pub struct InterStreamEnd { } /// Intermediary StreamEvent +#[derive(Debug, Serialize, Deserialize)] pub enum InterStreamEvent { Start, Chunk(String), diff --git a/src/chat/chat_options.rs b/src/chat/chat_options.rs index 630a7f1..260a7bd 100644 --- a/src/chat/chat_options.rs +++ b/src/chat/chat_options.rs @@ -4,7 +4,10 @@ //! //! Note 1: Later, we will probably allow to set the client //! Note 2: Splitting it out of the `ChatRequest` object allows for better reusability of each component. -#[derive(Debug, Clone, Default)] + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ChatOptions { /// Will be set for this request if Adapter/providers supports it. pub temperature: Option, diff --git a/src/chat/chat_req.rs b/src/chat/chat_req.rs index dccd4e5..befef98 100644 --- a/src/chat/chat_req.rs +++ b/src/chat/chat_req.rs @@ -1,10 +1,11 @@ //! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file). use crate::chat::MessageContent; +use serde::{Serialize, Deserialize}; // region: --- ChatRequest -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ChatRequest { pub system: Option, pub messages: Vec, @@ -84,7 +85,7 @@ impl ChatRequest { // region: --- ChatMessage -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: ChatRole, pub content: MessageContent, @@ -118,7 +119,7 @@ impl ChatMessage { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ChatRole { System, User, @@ -126,13 +127,13 @@ pub enum ChatRole { Tool, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum MessageExtra { Tool(ToolExtra), } #[allow(unused)] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolExtra { tool_id: String, } diff --git a/src/chat/chat_res.rs b/src/chat/chat_res.rs index 9c8fcbd..378c866 100644 --- a/src/chat/chat_res.rs +++ b/src/chat/chat_res.rs @@ -1,11 +1,13 @@ //! This module contains all the types related to a Chat Response (except ChatStream which has it file). +use serde::{Deserialize, Serialize}; + use crate::chat::{ChatStream, MessageContent}; use crate::ModelIden; // region: --- ChatResponse -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatResponse { pub content: Option, pub usage: MetaUsage, @@ -41,7 +43,7 @@ pub struct ChatStreamResponse { // region: --- MetaUsage /// IMPORTANT: This is **NOT SUPPORTED** for now. To show the API direction. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct MetaUsage { pub input_tokens: Option, pub output_tokens: Option, diff --git a/src/chat/chat_stream.rs b/src/chat/chat_stream.rs index e0be127..61717a6 100644 --- a/src/chat/chat_stream.rs +++ b/src/chat/chat_stream.rs @@ -2,12 +2,13 @@ use crate::adapter::inter_stream::{InterStreamEnd, InterStreamEvent}; use crate::chat::{MessageContent, MetaUsage}; use derive_more::From; use futures::Stream; +use serde::{Deserialize, Serialize}; use std::pin::Pin; use std::task::{Context, Poll}; type InterStreamType = Pin> + Send>>; -/// ChatStream is a Rust Future Stream that iterates through the events of a chat stream request. +/// ChatStream is a Rust Future Stream that iterates through the events of a chat stream request pub struct ChatStream { inter_stream: InterStreamType, } @@ -54,19 +55,19 @@ impl Stream for ChatStream { // region: --- ChatStreamEvent -#[derive(Debug, From)] +#[derive(Debug, From, Serialize, Deserialize)] pub enum ChatStreamEvent { Start, Chunk(StreamChunk), End(StreamEnd), } -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct StreamChunk { pub content: String, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct StreamEnd { /// The eventual captured UsageMeta pub captured_usage: Option, diff --git a/src/chat/message_content.rs b/src/chat/message_content.rs index 8b0df95..f0fb849 100644 --- a/src/chat/message_content.rs +++ b/src/chat/message_content.rs @@ -1,6 +1,8 @@ +use serde::{Deserialize, Serialize}; + /// For now, supports only Text, /// But the goal is to support multi-part message content (see below) -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum MessageContent { Text(String), } diff --git a/src/chat/printer.rs b/src/chat/printer.rs index 16491c3..b7c719e 100644 --- a/src/chat/printer.rs +++ b/src/chat/printer.rs @@ -1,10 +1,11 @@ use crate::chat::{ChatStreamEvent, ChatStreamResponse, StreamChunk}; use futures::StreamExt; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncWriteExt as _, Stdout}; // region: --- PrintChatOptions -#[derive(Debug, Default)] +#[derive(Debug, Default, Serialize, Deserialize)] pub struct PrintChatStreamOptions { print_events: Option, } diff --git a/src/chat/tool.rs b/src/chat/tool.rs index 968bd06..c53d455 100644 --- a/src/chat/tool.rs +++ b/src/chat/tool.rs @@ -1,6 +1,8 @@ +use serde::{Deserialize, Serialize}; use serde_json::Value; #[allow(unused)] // Not used yet +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { fn_name: String, fn_description: String, diff --git a/src/common/model_iden.rs b/src/common/model_iden.rs index 96d391b..3faba44 100644 --- a/src/common/model_iden.rs +++ b/src/common/model_iden.rs @@ -1,9 +1,11 @@ +use serde::{Deserialize, Serialize}; + use crate::adapter::AdapterKind; use crate::ModelName; /// Hold the adapter_kind and model_name in a efficient clonable way /// Note: For now, -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ModelIden { pub adapter_kind: AdapterKind, pub model_name: ModelName, diff --git a/src/common/model_name.rs b/src/common/model_name.rs index c02c952..251b438 100644 --- a/src/common/model_name.rs +++ b/src/common/model_name.rs @@ -1,7 +1,9 @@ use std::ops::Deref; use std::sync::Arc; -#[derive(Clone, Debug)] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ModelName(Arc); impl std::fmt::Display for ModelName { diff --git a/src/webc/web_client.rs b/src/webc/web_client.rs index 9ee2b27..91866e2 100644 --- a/src/webc/web_client.rs +++ b/src/webc/web_client.rs @@ -1,6 +1,7 @@ use crate::webc::{Error, Result}; use reqwest::header::HeaderMap; use reqwest::{Method, RequestBuilder, StatusCode}; +use serde::{Deserialize, Serialize}; use serde_json::Value; /// Simple reqwest client wrapper for this library. @@ -73,7 +74,7 @@ impl WebClient { // NOTE: This is not none-stream web response (assume json for this lib) // Streaming is handled with event-source or custom stream (for Cohere for example) -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct WebResponse { #[allow(unused)] pub status: StatusCode,