From cae43ab2d0bc6f250bd1b0cc8feb0e0d71a5bfd6 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Thu, 22 Aug 2024 09:45:08 +0200 Subject: [PATCH] Add identity verification --- src/endpoint/mod.rs | 20 +++++++++++++++++++- test-services/src/main.rs | 4 ++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index dfcf28f..160f928 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -9,7 +9,9 @@ use ::futures::future::BoxFuture; use ::futures::{Stream, StreamExt}; use bytes::Bytes; pub use context::{ContextInternal, InputMetadata}; -use restate_sdk_shared_core::{CoreVM, Header, HeaderMap, ResponseHead, VMError, VM}; +use restate_sdk_shared_core::{ + CoreVM, Header, HeaderMap, IdentityVerifier, KeyError, ResponseHead, VMError, VerifyError, VM, +}; use std::collections::HashMap; use std::future::poll_fn; use std::pin::Pin; @@ -88,6 +90,7 @@ impl Error { | ErrorInner::HandlerResult { .. } => 500, ErrorInner::BadDiscovery(_) => 415, ErrorInner::Header { .. } | ErrorInner::BadPath { .. } => 400, + ErrorInner::IdentityVerification(_) => 401, } } } @@ -100,6 +103,8 @@ enum ErrorInner { UnknownServiceHandler(String, String), #[error("Error when processing the request: {0:?}")] VM(#[from] VMError), + #[error("Error when verifying identity: {0:?}")] + IdentityVerification(#[from] VerifyError), #[error("Cannot convert header '{0}', reason: {1}")] Header(String, #[source] BoxError), #[error("Cannot reply to discovery, got accept header '{0}' but currently supported discovery is {DISCOVERY_CONTENT_TYPE}")] @@ -165,6 +170,7 @@ impl Service for BoxedService { pub struct Builder { svcs: HashMap, discovery: crate::discovery::Endpoint, + identity_verifier: IdentityVerifier, } impl Default for Builder { @@ -177,6 +183,7 @@ impl Default for Builder { protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, + identity_verifier: Default::default(), } } } @@ -204,10 +211,16 @@ impl Builder { self } + pub fn with_identity_key(mut self, key: &str) -> Result { + self.identity_verifier = self.identity_verifier.with_key(key)?; + Ok(self) + } + pub fn build(self) -> Endpoint { Endpoint(Arc::new(EndpointInner { svcs: self.svcs, discovery: self.discovery, + identity_verifier: self.identity_verifier, })) } } @@ -224,6 +237,7 @@ impl Endpoint { pub struct EndpointInner { svcs: HashMap, discovery: crate::discovery::Endpoint, + identity_verifier: IdentityVerifier, } impl Endpoint { @@ -232,6 +246,10 @@ impl Endpoint { H: HeaderMap, ::Error: std::error::Error + Send + Sync + 'static, { + if let Err(e) = self.0.identity_verifier.verify_identity(&headers, path) { + return Err(ErrorInner::IdentityVerification(e).into()); + } + let parts: Vec<&str> = path.split('/').collect(); if parts.last() == Some(&"discover") { diff --git a/test-services/src/main.rs b/test-services/src/main.rs index 2c9bc59..a22635e 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -77,6 +77,10 @@ async fn main() { )) } + if let Ok(key) = env::var("E2E_REQUEST_SIGNING_ENV") { + builder = builder.with_identity_key(&key).unwrap() + } + HttpServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) .await;