Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add signature verification to node response #105

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
license = "Apache-2.0"

[workspace.dependencies]

anyhow = "1.0.91"
async-trait = "0.1.86"
atoma-auth = { path = "./atoma-auth" }
Expand All @@ -24,8 +25,8 @@ chrono = "0.4.38"
clap = "4.5.20"
config = "0.14.1"
dcap-qvl = "0.1.6"
fastcrypto = { git = "https://github.com/MystenLabs/fastcrypto", rev="69d496c71fb37e3d22fe85e5bbfd4256d61422b9", package = "fastcrypto" }
fastcrypto-zkp = { git = "https://github.com/MystenLabs/fastcrypto", rev="69d496c71fb37e3d22fe85e5bbfd4256d61422b9", package = "fastcrypto-zkp" }
fastcrypto = { git = "https://github.com/MystenLabs/fastcrypto", rev = "69d496c71fb37e3d22fe85e5bbfd4256d61422b9", package = "fastcrypto" }
fastcrypto-zkp = { git = "https://github.com/MystenLabs/fastcrypto", rev = "69d496c71fb37e3d22fe85e5bbfd4256d61422b9", package = "fastcrypto-zkp" }
flume = "0.11.1"
futures = "0.3.31"
hex = "0.4.3"
Expand All @@ -43,7 +44,10 @@ serde_json = "1.0.138"
serde_yaml = "0.9.34"
serial_test = "3.1.1"
shared-crypto = { git = "https://github.com/mystenlabs/sui", package = "shared-crypto", tag = "testnet-v1.42.1" }
sqlx = { version = "0.8.2", features = ["postgres","runtime-tokio-native-tls"] }
sqlx = { version = "0.8.2", features = [
"postgres",
"runtime-tokio-native-tls",
] }
sui-keys = { git = "https://github.com/mystenlabs/sui", package = "sui-keys", tag = "testnet-v1.42.1" }
sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk", tag = "testnet-v1.42.1" }
sui-sdk-types = "0.0.2"
Expand Down
10 changes: 10 additions & 0 deletions atoma-auth/src/sui/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ impl Sui {
Ok(signature.encode_base64())
}

/// Get the underlying keystore
///
/// # Returns
///
/// Returns the keystore.
#[must_use]
pub fn get_keystore(&self) -> &Keystore {
&self.wallet_ctx.config.keystore
}

/// Sign a hash using the wallet's private key
///
/// # Arguments
Expand Down
13 changes: 10 additions & 3 deletions atoma-proxy/src/server/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use super::metrics::{
};
use super::request_model::RequestModel;
use super::{
handle_status_code_error, update_state_manager, verify_response_hash_and_signature,
handle_status_code_error, update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY,
RESPONSE_HASH_KEY,
};
use crate::server::{Result, DEFAULT_MAX_TOKENS, MAX_COMPLETION_TOKENS, MAX_TOKENS, MODEL};
Expand Down Expand Up @@ -489,6 +489,7 @@ pub fn confidential_chat_completions_create_stream(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
async fn handle_non_streaming_response(
state: &ProxyState,
node_address: &String,
Expand Down Expand Up @@ -524,7 +525,7 @@ async fn handle_non_streaming_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response = response
let mut response = response
.json::<Value>()
.await
.map_err(|err| AtomaProxyError::InternalError {
Expand Down Expand Up @@ -553,7 +554,12 @@ async fn handle_non_streaming_response(
.map_or(0, |n| n as i64);

let verify_hash = endpoint != CONFIDENTIAL_CHAT_COMPLETIONS_PATH;
verify_response_hash_and_signature(&response.0, verify_hash)?;

let guard: tokio::sync::RwLockReadGuard<'_, atoma_auth::Sui> = state.sui.blocking_read();
let keystore = guard.get_keystore();
let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

state
.state_manager_sender
Expand Down Expand Up @@ -711,6 +717,7 @@ async fn handle_streaming_response(
state.state_manager_sender.clone(),
selected_stack_small_id,
estimated_total_tokens,
state.sui.clone(),
start,
node_id,
model_name,
Expand Down
12 changes: 9 additions & 3 deletions atoma-proxy/src/server/handlers/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use super::{
TOTAL_FAILED_REQUESTS, TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS,
},
request_model::RequestModel,
update_state_manager, verify_response_hash_and_signature, RESPONSE_HASH_KEY,
update_state_manager, verify_and_sign_response, PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY,
};
use crate::server::Result;

Expand Down Expand Up @@ -355,6 +355,7 @@ pub async fn confidential_embeddings_create(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
async fn handle_embeddings_response(
state: &ProxyState,
node_address: String,
Expand Down Expand Up @@ -394,7 +395,7 @@ async fn handle_embeddings_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response =
let mut response =
response
.json::<Value>()
.await
Expand All @@ -403,8 +404,13 @@ async fn handle_embeddings_response(
endpoint: endpoint.to_string(),
})?;

let guard = state.sui.blocking_read();
let keystore = guard.get_keystore();
let verify_hash = endpoint != CONFIDENTIAL_EMBEDDINGS_PATH;
verify_response_hash_and_signature(&response, verify_hash)?;

let proxy_signature = verify_and_sign_response(&response, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

let num_input_compute_units = if endpoint == CONFIDENTIAL_EMBEDDINGS_PATH {
response
Expand Down
17 changes: 13 additions & 4 deletions atoma-proxy/src/server/handlers/image_generations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ use crate::server::error::AtomaProxyError;
use crate::server::types::{ConfidentialComputeRequest, ConfidentialComputeResponse};
use crate::server::{http_server::ProxyState, middleware::RequestMetadataExtension};

use super::handle_status_code_error;
use super::metrics::{
IMAGE_GEN_LATENCY_METRICS, IMAGE_GEN_NUM_REQUESTS, TOTAL_COMPLETED_REQUESTS,
TOTAL_FAILED_IMAGE_GENERATION_REQUESTS, TOTAL_FAILED_REQUESTS,
};
use super::{handle_status_code_error, verify_response_hash_and_signature};
use super::{request_model::RequestModel, update_state_manager, RESPONSE_HASH_KEY};
use super::{
request_model::RequestModel, update_state_manager, verify_and_sign_response,
PROXY_SIGNATURE_KEY, RESPONSE_HASH_KEY,
};
use crate::server::{Result, MODEL};

/// Path for the confidential image generations endpoint.
Expand Down Expand Up @@ -332,6 +335,7 @@ pub async fn confidential_image_generations_create(
)
)]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::significant_drop_tightening)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've allowed this as refactoring would be quite substantial, see #210

async fn handle_image_generation_response(
state: &ProxyState,
node_address: String,
Expand Down Expand Up @@ -369,7 +373,7 @@ async fn handle_image_generation_response(
handle_status_code_error(response.status(), &endpoint, error)?;
}

let response = response
let mut response = response
.json::<Value>()
.await
.map_err(|err| AtomaProxyError::InternalError {
Expand All @@ -378,8 +382,13 @@ async fn handle_image_generation_response(
})
.map(Json)?;

let guard = state.sui.blocking_read();
let keystore = guard.get_keystore();
let verify_hash = endpoint != CONFIDENTIAL_IMAGE_GENERATIONS_PATH;
verify_response_hash_and_signature(&response.0, verify_hash)?;

let proxy_signature = verify_and_sign_response(&response.0, verify_hash, keystore)?;

response[PROXY_SIGNATURE_KEY] = Value::String(proxy_signature);

// Update the node throughput performance
state
Expand Down
56 changes: 44 additions & 12 deletions atoma-proxy/src/server/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::str::FromStr;

use atoma_state::types::AtomaAtomaStateManagerEvent;
use base64::engine::{general_purpose::STANDARD, Engine};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use blake2::Digest;
use fastcrypto::{
ed25519::{Ed25519PublicKey, Ed25519Signature},
Expand All @@ -11,6 +11,7 @@ use fastcrypto::{
};
use flume::Sender;
use reqwest::StatusCode;
use sui_keys::keystore::{AccountKeystore, Keystore};
use sui_sdk::types::crypto::{PublicKey, Signature, SignatureScheme, SuiSignature};
use tracing::instrument;

Expand All @@ -31,6 +32,9 @@ pub const RESPONSE_HASH_KEY: &str = "response_hash";
/// Key for the signature in the payload
pub const SIGNATURE_KEY: &str = "signature";

/// Key for the proxy signature in the payload
pub const PROXY_SIGNATURE_KEY: &str = "proxy_signature";

/// Updates the state manager with token usage and hash information for a stack.
///
/// This function performs two main operations:
Expand Down Expand Up @@ -81,42 +85,54 @@ pub fn update_state_manager(
Ok(())
}

/// Verifies a Sui signature from a handler response
/// Verifies a Sui signature and creates a new signature using the proxy's key
///
/// # Arguments
///
/// * `payload` - JSON payload containing the response hash and its signature
/// * `node_public_key` - Public key of the node that signed the response
/// * `proxy_keystore` - Keystore containing the proxy's signing key
///
/// # Returns
///
/// Returns `Ok(())` if verification succeeds,
/// Returns `Ok(String)` with the new signature if verification succeeds,
/// or an error if verification fails or signing fails
///
/// # Errors
///
/// This function will return an error if:
/// - The payload format is invalid
/// - The signature verification fails
/// - Creating the new signature fails
#[instrument(level = "debug", skip_all)]
pub fn verify_response_hash_and_signature(
pub fn verify_and_sign_response(
payload: &serde_json::Value,
verify_hash: bool,
) -> Result<()> {
keystore: &Keystore,
) -> Result<String> {
// Extract response hash and signature from payload
let response_hash =
let response_hash_str =
payload[RESPONSE_HASH_KEY]
.as_str()
.ok_or_else(|| AtomaProxyError::InternalError {
message: "Missing response_hash in payload".to_string(),
endpoint: "verify_signature".to_string(),
})?;
let response_hash = STANDARD.decode(response_hash).unwrap();

if verify_hash {
verify_response_hash(payload, &response_hash)?;
// Decode base64 string to bytes for verification
let response_hash_bytes =
BASE64
.decode(response_hash_str)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to decode response hash: {e}"),
endpoint: "verify_signature".to_string(),
})?;
verify_response_hash(payload, &response_hash_bytes)?;
}

let response_hash_bytes = BASE64.decode(response_hash_str).unwrap();

let node_signature =
payload[SIGNATURE_KEY]
.as_str()
Expand Down Expand Up @@ -156,7 +172,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to verify ed25519 signature: {e}"),
endpoint: "verify_signature".to_string(),
Expand All @@ -177,7 +193,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|_| AtomaProxyError::InternalError {
message: "Failed to verify secp256k1 signature".to_string(),
endpoint: "verify_signature".to_string(),
Expand All @@ -198,7 +214,7 @@ pub fn verify_response_hash_and_signature(
}
})?;
public_key
.verify(response_hash.as_slice(), &signature)
.verify(response_hash_bytes.as_slice(), &signature)
.map_err(|_| AtomaProxyError::InternalError {
message: "Failed to verify secp256r1 signature".to_string(),
endpoint: "verify_signature".to_string(),
Expand All @@ -212,7 +228,23 @@ pub fn verify_response_hash_and_signature(
}
}

Ok(())
// Sign with proxy's key
let proxy_signature = match keystore {
Keystore::File(keystore) => keystore
.sign_hashed(&keystore.addresses()[0], &response_hash_bytes)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to create proxy signature: {e}"),
endpoint: "verify_signature".to_string(),
})?,
Keystore::InMem(keystore) => keystore
.sign_hashed(&keystore.addresses()[0], &response_hash_bytes)
.map_err(|e| AtomaProxyError::InternalError {
message: format!("Failed to create proxy signature: {e}"),
endpoint: "verify_signature".to_string(),
})?,
};
// Convert signature to base64
Ok(BASE64.encode(proxy_signature.as_ref()))
}

/// Verifies that a response hash matches the computed hash of the payload
Expand Down
Loading
Loading