From 263c533f629a55fc190075905c5fc7a865451579 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 31 Dec 2024 15:12:16 +0000 Subject: [PATCH 1/5] add changes --- atoma-proxy/src/server/handlers/embeddings.rs | 16 ++++------------ .../src/server/handlers/image_generations.rs | 4 ++-- atoma-proxy/src/server/middleware.rs | 2 +- atoma-proxy/src/server/types.rs | 5 ++++- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/atoma-proxy/src/server/handlers/embeddings.rs b/atoma-proxy/src/server/handlers/embeddings.rs index ca3518c..dd02522 100644 --- a/atoma-proxy/src/server/handlers/embeddings.rs +++ b/atoma-proxy/src/server/handlers/embeddings.rs @@ -273,18 +273,10 @@ pub async fn confidential_embeddings_create( // with a "total_tokens" field, which correctly specifies the number of total tokens // processed by the node, as the latter is running within a TEE. let total_tokens = response - .get("total_tokens") - .map(|u| { - u.as_u64().ok_or_else(|| AtomaProxyError::InternalError { - message: "Failed to get total tokens".to_string(), - endpoint: metadata.endpoint.clone(), - }) - }) - .transpose() - .map_err(|e| AtomaProxyError::InternalError { - message: format!("Failed to get total tokens: {}", e), - endpoint: metadata.endpoint.clone(), - })? + .get("usage") + .and_then(|usage| usage.get("total_tokens")) + .and_then(|total_tokens| total_tokens.as_u64()) + .map(|n| n as i64) .unwrap_or(0); update_state_manager( &state.state_manager_sender, diff --git a/atoma-proxy/src/server/handlers/image_generations.rs b/atoma-proxy/src/server/handlers/image_generations.rs index 6a61d4e..92e1850 100644 --- a/atoma-proxy/src/server/handlers/image_generations.rs +++ b/atoma-proxy/src/server/handlers/image_generations.rs @@ -249,8 +249,8 @@ pub async fn confidential_image_generations_create( .await { Ok(response) => { - // NOTE: At this point, we do not need to update the stack num tokens, - // because the image generation response was correctly generated. + // NOTE: At this point, we do not need to update the stack num compute units, + // because the image generation response was correctly generated by a TEE node. Ok(response.into_response()) } Err(e) => { diff --git a/atoma-proxy/src/server/middleware.rs b/atoma-proxy/src/server/middleware.rs index 50ab60f..c72c3b3 100644 --- a/atoma-proxy/src/server/middleware.rs +++ b/atoma-proxy/src/server/middleware.rs @@ -420,7 +420,7 @@ pub async fn confidential_compute_middleware( let num_compute_units = if endpoint == CONFIDENTIAL_IMAGE_GENERATIONS_PATH { confidential_compute_request - .max_tokens + .num_compute_units .unwrap_or(DEFAULT_IMAGE_RESOLUTION) as i64 } else { MAX_NUM_TOKENS_FOR_CONFIDENTIAL_COMPUTE diff --git a/atoma-proxy/src/server/types.rs b/atoma-proxy/src/server/types.rs index c8844fe..79fab1e 100644 --- a/atoma-proxy/src/server/types.rs +++ b/atoma-proxy/src/server/types.rs @@ -19,6 +19,9 @@ pub struct ConfidentialComputeRequest { /// Client's public key for Diffie-Hellman key exchange (base64 encoded) pub client_dh_public_key: String, + /// Node's public key for Diffie-Hellman key exchange (base64 encoded) + pub node_dh_public_key: String, + /// Hash of the original plaintext body for integrity verification (base64 encoded) pub plaintext_body_hash: String, @@ -30,5 +33,5 @@ pub struct ConfidentialComputeRequest { /// Number of compute units to be used for the request, for image generations, /// as this value is known in advance (the number of pixels to generate) - pub max_tokens: Option, + pub num_compute_units: Option, } From a7444c7c2a3f635542cfa94ae3598925bb2cdcbd Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 31 Dec 2024 17:49:23 +0000 Subject: [PATCH 2/5] add changes --- atoma-proxy/src/server/handlers/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/atoma-proxy/src/server/handlers/mod.rs b/atoma-proxy/src/server/handlers/mod.rs index 0182bd6..4849ccc 100644 --- a/atoma-proxy/src/server/handlers/mod.rs +++ b/atoma-proxy/src/server/handlers/mod.rs @@ -1,5 +1,6 @@ use atoma_state::types::AtomaAtomaStateManagerEvent; use flume::Sender; +use tracing::instrument; use super::error::AtomaProxyError; use crate::server::Result; @@ -34,6 +35,11 @@ pub mod select_node_public_key; /// This function will return an error if: /// - The state manager channel is closed /// - Either update operation fails to complete +#[instrument( + level = "info", + skip_all, + fields(stack_small_id, estimated_total_tokens, total_tokens, endpoint) +)] pub fn update_state_manager( state_manager_sender: &Sender, stack_small_id: i64, From d4903a88365de4e292e4b299ab1b494170e712fe Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Wed, 1 Jan 2025 11:28:18 +0000 Subject: [PATCH 3/5] first commit --- Cargo.lock | 1 + Cargo.toml | 2 +- atoma-auth/src/sui/mod.rs | 9 ++ atoma-proxy/Cargo.toml | 1 + .../src/server/handlers/chat_completions.rs | 12 +- atoma-proxy/src/server/handlers/embeddings.rs | 13 +- .../src/server/handlers/image_generations.rs | 10 +- atoma-proxy/src/server/handlers/mod.rs | 134 ++++++++++++++++++ atoma-proxy/src/server/streamer.rs | 17 ++- 9 files changed, 189 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6598db..2ba98d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -604,6 +604,7 @@ dependencies = [ "blake2", "clap", "config", + "fastcrypto 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "flume", "futures", "hf-hub", diff --git a/Cargo.toml b/Cargo.toml index d9bf39d..89735b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ chrono = "0.4.38" clap = "4.5.20" config = "0.14.1" dcap-qvl = "0.1.6" +fastcrypto = "0.1.8" flume = "0.11.1" futures = "0.3.31" hf-hub = "0.3.2" @@ -53,5 +54,4 @@ utoipa-swagger-ui = "8.0.3" uuid = "1.11.0" x25519-dalek = "2.0.1" zeroize = "1.8.1" -fastcrypto = "0.1.8" tower-http = "0.6.2" diff --git a/atoma-auth/src/sui/mod.rs b/atoma-auth/src/sui/mod.rs index cb30a60..ae642fe 100644 --- a/atoma-auth/src/sui/mod.rs +++ b/atoma-auth/src/sui/mod.rs @@ -215,6 +215,15 @@ impl Sui { Ok(signature.encode_base64()) } + /// Get the underlying keystore + /// + /// # Returns + /// + /// Returns the keystore. + pub fn get_keystore(&self) -> &Keystore { + &self.wallet_ctx.config.keystore + } + /// Sign a hash using the wallet's private key /// /// # Arguments diff --git a/atoma-proxy/Cargo.toml b/atoma-proxy/Cargo.toml index 782361e..c736031 100644 --- a/atoma-proxy/Cargo.toml +++ b/atoma-proxy/Cargo.toml @@ -17,6 +17,7 @@ base64 = { workspace = true } blake2.workspace = true clap.workspace = true config.workspace = true +fastcrypto.workspace = true flume.workspace = true futures = { workspace = true } hf-hub = { workspace = true } diff --git a/atoma-proxy/src/server/handlers/chat_completions.rs b/atoma-proxy/src/server/handlers/chat_completions.rs index 68d8398..18d395e 100644 --- a/atoma-proxy/src/server/handlers/chat_completions.rs +++ b/atoma-proxy/src/server/handlers/chat_completions.rs @@ -16,7 +16,7 @@ use tracing::instrument; use utoipa::{OpenApi, ToSchema}; use super::request_model::RequestModel; -use super::update_state_manager; +use super::{update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY}; use crate::server::Result; /// Path for the confidential chat completions endpoint. @@ -464,7 +464,7 @@ async fn handle_non_streaming_response( let client = reqwest::Client::new(); let time = Instant::now(); - let response = client + let mut response = client .post(format!("{}{}", node_address, endpoint)) .headers(headers) .json(&payload) @@ -504,6 +504,12 @@ async fn handle_non_streaming_response( .map(|n| n as i64) .unwrap_or(0); + let guard = state.sui.read().await; + let keystore = guard.get_keystore(); + let proxy_signature = verify_and_sign_response(&response.0, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); + state .state_manager_sender .send( @@ -622,12 +628,14 @@ async fn handle_streaming_response( let stream = response.bytes_stream(); + let guard = state.sui.read().await; // Create the SSE stream let stream = Sse::new(Streamer::new( stream, state.state_manager_sender.clone(), selected_stack_small_id, estimated_total_tokens, + guard, start, node_id, model_name, diff --git a/atoma-proxy/src/server/handlers/embeddings.rs b/atoma-proxy/src/server/handlers/embeddings.rs index dd02522..6c27186 100644 --- a/atoma-proxy/src/server/handlers/embeddings.rs +++ b/atoma-proxy/src/server/handlers/embeddings.rs @@ -19,7 +19,10 @@ use crate::server::{ types::ConfidentialComputeRequest, }; -use super::{request_model::RequestModel, update_state_manager}; +use super::{ + request_model::RequestModel, update_state_manager, verify_and_sign_response, + PROXY_SIGNATURE_KEY, +}; use crate::server::Result; /// Path for the confidential embeddings endpoint. @@ -349,7 +352,7 @@ async fn handle_embeddings_response( let client = reqwest::Client::new(); let time = Instant::now(); // Send the request to the AI node - let response = client + let mut response = client .post(format!("{}{}", node_address, endpoint)) .headers(headers) .json(&payload) @@ -366,6 +369,12 @@ async fn handle_embeddings_response( endpoint: endpoint.to_string(), })?; + let guard = state.sui.read().await; + let keystore = guard.get_keystore(); + let proxy_signature = verify_and_sign_response(&response, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); + let num_input_compute_units = if endpoint == CONFIDENTIAL_EMBEDDINGS_PATH { response .get("total_tokens") diff --git a/atoma-proxy/src/server/handlers/image_generations.rs b/atoma-proxy/src/server/handlers/image_generations.rs index 92e1850..3f1caa5 100644 --- a/atoma-proxy/src/server/handlers/image_generations.rs +++ b/atoma-proxy/src/server/handlers/image_generations.rs @@ -16,7 +16,7 @@ use crate::server::types::ConfidentialComputeRequest; use crate::server::{http_server::ProxyState, middleware::RequestMetadataExtension}; use super::request_model::RequestModel; -use super::update_state_manager; +use super::{update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY}; use crate::server::Result; /// Path for the confidential image generations endpoint. @@ -318,7 +318,7 @@ async fn handle_image_generation_response( let client = reqwest::Client::new(); let time = Instant::now(); // Send the request to the AI node - let response = client + let mut response = client .post(format!("{}{}", node_address, endpoint)) .headers(headers) .json(&payload) @@ -336,6 +336,12 @@ async fn handle_image_generation_response( }) .map(Json)?; + let guard = state.sui.read().await; + let keystore = guard.get_keystore(); + let proxy_signature = verify_and_sign_response(&response.0, keystore)?; + + response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature); + // Update the node throughput performance state .state_manager_sender diff --git a/atoma-proxy/src/server/handlers/mod.rs b/atoma-proxy/src/server/handlers/mod.rs index 4849ccc..a9bfa10 100644 --- a/atoma-proxy/src/server/handlers/mod.rs +++ b/atoma-proxy/src/server/handlers/mod.rs @@ -1,5 +1,16 @@ +use std::str::FromStr; + use atoma_state::types::AtomaAtomaStateManagerEvent; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use fastcrypto::{ + ed25519::{Ed25519PublicKey, Ed25519Signature}, + secp256k1::{Secp256k1PublicKey, Secp256k1Signature}, + secp256r1::{Secp256r1PublicKey, Secp256r1Signature}, + traits::{ToFromBytes, VerifyingKey}, +}; use flume::Sender; +use sui_keys::keystore::{AccountKeystore, Keystore}; +use sui_sdk::types::crypto::{PublicKey, Signature, SignatureScheme, SuiSignature}; use tracing::instrument; use super::error::AtomaProxyError; @@ -11,6 +22,15 @@ pub mod image_generations; pub mod request_model; pub mod select_node_public_key; +/// Key for the proxy signature in the payload +pub const PROXY_SIGNATURE_KEY: &str = "proxy_signature"; + +/// Key for the response hash in the payload +const RESPONSE_HASH_KEY: &str = "response_hash"; + +/// Key for the signature in the payload +const SIGNATURE_KEY: &str = "signature"; + /// Updates the state manager with token usage and hash information for a stack. /// /// This function performs two main operations: @@ -60,3 +80,117 @@ pub fn update_state_manager( })?; Ok(()) } + +/// Verifies a Sui signature and creates a new signature using the proxy's key +/// +/// # Arguments +/// +/// * `payload` - JSON payload containing the response hash and its signature +/// * `node_public_key` - Public key of the node that signed the response +/// * `proxy_keystore` - Keystore containing the proxy's signing key +/// +/// # Returns +/// +/// Returns `Ok(String)` with the new signature if verification succeeds, +/// or an error if verification fails or signing fails +/// +/// # Errors +/// +/// This function will return an error if: +/// - The payload format is invalid +/// - The signature verification fails +/// - Creating the new signature fails +#[instrument(level = "debug", skip_all)] +pub fn verify_and_sign_response( + payload: &serde_json::Value, + keystore: &Keystore, +) -> Result { + // Extract response hash and signature from payload + let response_hash = + payload[RESPONSE_HASH_KEY] + .as_str() + .ok_or_else(|| AtomaProxyError::InternalError { + message: "Missing response_hash in payload".to_string(), + endpoint: "verify_signature".to_string(), + })?; + + let node_signature = + payload[SIGNATURE_KEY] + .as_str() + .ok_or_else(|| AtomaProxyError::InternalError { + message: "Missing signature in payload".to_string(), + endpoint: "verify_signature".to_string(), + })?; + + let signature = + Signature::from_str(node_signature).map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to create signature: {}", e), + endpoint: "verify_signature".to_string(), + })?; + + let public_key_bytes = signature.public_key_bytes(); + let public_key = + PublicKey::try_from_bytes(signature.scheme(), public_key_bytes).map_err(|e| { + AtomaProxyError::InternalError { + message: format!("Failed to create public key: {}", e), + endpoint: "verify_signature".to_string(), + } + })?; + + match signature.scheme() { + SignatureScheme::ED25519 => { + let public_key = Ed25519PublicKey::from_bytes(public_key.as_ref()).unwrap(); + let signature = Ed25519Signature::from_bytes(signature.as_ref()).unwrap(); + public_key + .verify(response_hash.as_bytes(), &signature) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to verify signature: {}", e), + endpoint: "verify_signature".to_string(), + })?; + } + SignatureScheme::Secp256k1 => { + let public_key = Secp256k1PublicKey::from_bytes(public_key.as_ref()).unwrap(); + let signature = Secp256k1Signature::from_bytes(signature.as_ref()).unwrap(); + public_key + .verify(response_hash.as_bytes(), &signature) + .map_err(|_| AtomaProxyError::InternalError { + message: "Failed to verify signature".to_string(), + endpoint: "verify_signature".to_string(), + })?; + } + SignatureScheme::Secp256r1 => { + let public_key = Secp256r1PublicKey::from_bytes(public_key.as_ref()).unwrap(); + let signature = Secp256r1Signature::from_bytes(signature.as_ref()).unwrap(); + public_key + .verify(response_hash.as_bytes(), &signature) + .map_err(|_| AtomaProxyError::InternalError { + message: "Failed to verify signature".to_string(), + endpoint: "verify_signature".to_string(), + })?; + } + _ => { + return Err(AtomaProxyError::InternalError { + message: "Currently unsupported signature scheme".to_string(), + endpoint: "verify_signature".to_string(), + }); + } + } + + // Sign with proxy's key + let proxy_signature = match keystore { + Keystore::File(keystore) => keystore + .sign_hashed(&keystore.addresses()[0], response_hash.as_bytes()) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to create proxy signature: {}", e), + endpoint: "verify_signature".to_string(), + })?, + Keystore::InMem(keystore) => keystore + .sign_hashed(&keystore.addresses()[0], response_hash.as_bytes()) + .map_err(|e| AtomaProxyError::InternalError { + message: format!("Failed to create proxy signature: {}", e), + endpoint: "verify_signature".to_string(), + })?, + }; + // Convert signature to base64 + Ok(BASE64.encode(proxy_signature.as_ref())) +} diff --git a/atoma-proxy/src/server/streamer.rs b/atoma-proxy/src/server/streamer.rs index 3bc6e8d..3d2a925 100644 --- a/atoma-proxy/src/server/streamer.rs +++ b/atoma-proxy/src/server/streamer.rs @@ -1,9 +1,11 @@ use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, time::Instant, }; +use atoma_auth::Sui; use atoma_state::types::AtomaAtomaStateManagerEvent; use axum::body::Bytes; use axum::{response::sse::Event, Error}; @@ -12,10 +14,13 @@ use futures::Stream; use reqwest; use serde_json::Value; use sqlx::types::chrono::{DateTime, Utc}; +use tokio::sync::RwLockReadGuard; use tracing::{error, instrument}; use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager}; +use super::handlers::verify_and_sign_response; + /// The chunk that indicates the end of a streaming response const DONE_CHUNK: &str = "[DONE]"; @@ -32,13 +37,15 @@ const CHOICES: &str = "choices"; const USAGE: &str = "usage"; /// A structure for streaming chat completion chunks. -pub struct Streamer { +pub struct Streamer<'a> { /// The stream of bytes currently being processed stream: Pin> + Send>>, /// Current status of the stream status: StreamStatus, /// Estimated total tokens for the stream estimated_total_tokens: i64, + /// Keystore + keystore: RwLockReadGuard<'a, Sui>, /// Stack small id stack_small_id: i64, /// State manager sender @@ -68,7 +75,7 @@ pub enum StreamStatus { Failed(String), } -impl Streamer { +impl<'a> Streamer<'a> { /// Creates a new Streamer instance #[allow(clippy::too_many_arguments)] pub fn new( @@ -76,6 +83,7 @@ impl Streamer { state_manager_sender: Sender, stack_small_id: i64, estimated_total_tokens: i64, + keystore: RwLockReadGuard<'a, Sui>, start: Instant, node_id: i64, model_name: String, @@ -85,6 +93,7 @@ impl Streamer { stream: Box::pin(stream), status: StreamStatus::NotStarted, estimated_total_tokens, + keystore, stack_small_id, state_manager_sender, start, @@ -227,7 +236,7 @@ impl Streamer { } } -impl Stream for Streamer { +impl<'a> Stream for Streamer<'a> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -313,6 +322,8 @@ impl Stream for Streamer { } } else if let Some(usage) = chunk.get(USAGE) { self.status = StreamStatus::Completed; + verify_and_sign_response(&chunk, self.keystore.get_keystore()) + .map_err(|e| Error::new(e.to_string()))?; self.handle_final_chunk(usage)?; } From 5f25fa65da589155ccc3d1291c337d4330c47db2 Mon Sep 17 00:00:00 2001 From: chad Date: Mon, 10 Feb 2025 13:26:58 -0500 Subject: [PATCH 4/5] fix: use blocking reads for signature verification --- .../src/server/handlers/chat_completions.rs | 3 +-- atoma-proxy/src/server/streamer.rs | 22 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/atoma-proxy/src/server/handlers/chat_completions.rs b/atoma-proxy/src/server/handlers/chat_completions.rs index 815d721..7f7a50a 100644 --- a/atoma-proxy/src/server/handlers/chat_completions.rs +++ b/atoma-proxy/src/server/handlers/chat_completions.rs @@ -710,14 +710,13 @@ async fn handle_streaming_response( let stream = response.bytes_stream(); - let guard = state.sui.read().await; // Create the SSE stream let stream = Sse::new(Streamer::new( stream, state.state_manager_sender.clone(), selected_stack_small_id, estimated_total_tokens, - guard, + state.sui.clone(), start, node_id, model_name, diff --git a/atoma-proxy/src/server/streamer.rs b/atoma-proxy/src/server/streamer.rs index 65ce626..3813292 100644 --- a/atoma-proxy/src/server/streamer.rs +++ b/atoma-proxy/src/server/streamer.rs @@ -10,11 +10,12 @@ use futures::Stream; use reqwest; use serde_json::Value; use sqlx::types::chrono::{DateTime, Utc}; -use tokio::sync::RwLockReadGuard; +use tokio::sync::RwLock; use crate::server::handlers::{chat_completions::CHAT_COMPLETIONS_PATH, update_state_manager}; use super::handlers::verify_and_sign_response; +use std::sync::Arc; use std::{ pin::Pin, task::{Context, Poll}, @@ -43,7 +44,7 @@ const CHOICES: &str = "choices"; const USAGE: &str = "usage"; /// A structure for streaming chat completion chunks. -pub struct Streamer<'a> { +pub struct Streamer { /// The stream of bytes currently being processed stream: Pin> + Send>>, /// Current status of the stream @@ -51,7 +52,7 @@ pub struct Streamer<'a> { /// Estimated total tokens for the stream estimated_total_tokens: i64, /// Keystore - keystore: RwLockReadGuard<'a, Sui>, + sui: Arc>, /// Stack small id stack_small_id: i64, /// State manager sender @@ -85,7 +86,7 @@ pub enum StreamStatus { Failed(String), } -impl<'a> Streamer<'a> { +impl Streamer { /// Creates a new Streamer instance #[allow(clippy::too_many_arguments)] pub fn new( @@ -93,7 +94,7 @@ impl<'a> Streamer<'a> { state_manager_sender: Sender, stack_small_id: i64, estimated_total_tokens: i64, - keystore: RwLockReadGuard<'a, Sui>, + sui: Arc>, start: Instant, node_id: i64, model_name: String, @@ -103,7 +104,7 @@ impl<'a> Streamer<'a> { stream: Box::pin(stream), status: StreamStatus::NotStarted, estimated_total_tokens, - keystore, + sui, stack_small_id, state_manager_sender, start, @@ -274,7 +275,7 @@ impl<'a> Streamer<'a> { } } -impl<'a> Stream for Streamer<'a> { +impl Stream for Streamer { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -461,8 +462,11 @@ impl<'a> Stream for Streamer<'a> { } } else if let Some(usage) = chunk.get(USAGE) { self.status = StreamStatus::Completed; - verify_and_sign_response(&chunk, verify_hash, self.keystore.get_keystore()) - .map_err(|e| Error::new(e.to_string()))?; + let _ = { + let guard = self.sui.blocking_read(); + verify_and_sign_response(&chunk, verify_hash, guard.get_keystore()) + .map_err(|e| Error::new(e.to_string()))? + }; // guard is dropped immediately after signature is created self.handle_final_chunk(usage)?; } From 2fa5bd4ab99b016f6867241181786ee29670a87d Mon Sep 17 00:00:00 2001 From: chad Date: Mon, 10 Feb 2025 13:27:11 -0500 Subject: [PATCH 5/5] fix: use blocking reads for signature verification --- atoma-proxy/src/server/handlers/chat_completions.rs | 3 ++- atoma-proxy/src/server/handlers/embeddings.rs | 4 +++- atoma-proxy/src/server/handlers/image_generations.rs | 3 ++- atoma-proxy/src/server/handlers/mod.rs | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/atoma-proxy/src/server/handlers/chat_completions.rs b/atoma-proxy/src/server/handlers/chat_completions.rs index 7f7a50a..c376413 100644 --- a/atoma-proxy/src/server/handlers/chat_completions.rs +++ b/atoma-proxy/src/server/handlers/chat_completions.rs @@ -489,6 +489,7 @@ pub fn confidential_chat_completions_create_stream( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_non_streaming_response( state: &ProxyState, node_address: &String, @@ -554,7 +555,7 @@ async fn handle_non_streaming_response( let verify_hash = endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH; - let guard = state.sui.read().await; + let guard: tokio::sync::RwLockReadGuard<'_, atoma_auth::Sui> = state.sui.blocking_read(); let keystore = guard.get_keystore(); let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?; diff --git a/atoma-proxy/src/server/handlers/embeddings.rs b/atoma-proxy/src/server/handlers/embeddings.rs index ab0de3c..c37b5ae 100644 --- a/atoma-proxy/src/server/handlers/embeddings.rs +++ b/atoma-proxy/src/server/handlers/embeddings.rs @@ -355,6 +355,7 @@ pub async fn confidential_embeddings_create( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_embeddings_response( state: &ProxyState, node_address: String, @@ -402,7 +403,8 @@ async fn handle_embeddings_response( message: format!("Failed to parse embeddings response: {err:?}"), endpoint: endpoint.to_string(), })?; - let guard = state.sui.read().await; + + let guard = state.sui.blocking_read(); let keystore = guard.get_keystore(); let verify_hash = endpoint != CONFIDENTIAL_EMBEDDINGS_PATH; diff --git a/atoma-proxy/src/server/handlers/image_generations.rs b/atoma-proxy/src/server/handlers/image_generations.rs index 1c84292..ab6e09c 100644 --- a/atoma-proxy/src/server/handlers/image_generations.rs +++ b/atoma-proxy/src/server/handlers/image_generations.rs @@ -335,6 +335,7 @@ pub async fn confidential_image_generations_create( ) )] #[allow(clippy::too_many_arguments)] +#[allow(clippy::significant_drop_tightening)] async fn handle_image_generation_response( state: &ProxyState, node_address: String, @@ -381,7 +382,7 @@ async fn handle_image_generation_response( }) .map(Json)?; - let guard = state.sui.read().await; + let guard = state.sui.blocking_read(); let keystore = guard.get_keystore(); let verify_hash = endpoint != CONFIDENTIAL_IMAGE_GENERATIONS_PATH; diff --git a/atoma-proxy/src/server/handlers/mod.rs b/atoma-proxy/src/server/handlers/mod.rs index 1ba5497..8d099d8 100644 --- a/atoma-proxy/src/server/handlers/mod.rs +++ b/atoma-proxy/src/server/handlers/mod.rs @@ -233,13 +233,13 @@ pub fn verify_and_sign_response( Keystore::File(keystore) => keystore .sign_hashed(&keystore.addresses()[0], &response_hash_bytes) .map_err(|e| AtomaProxyError::InternalError { - message: format!("Failed to create proxy signature: {}", e), + message: format!("Failed to create proxy signature: {e}"), endpoint: "verify_signature".to_string(), })?, Keystore::InMem(keystore) => keystore .sign_hashed(&keystore.addresses()[0], &response_hash_bytes) .map_err(|e| AtomaProxyError::InternalError { - message: format!("Failed to create proxy signature: {}", e), + message: format!("Failed to create proxy signature: {e}"), endpoint: "verify_signature".to_string(), })?, };