diff --git a/src/builder.rs b/src/builder.rs index 7f15cced6..02ad77ea2 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -694,10 +694,7 @@ impl NodeBuilder { let vss_seed_bytes: [u8; 32] = vss_xprv.private_key.secret_bytes(); let vss_store = - VssStore::new(vss_url, store_id, vss_seed_bytes, header_provider).map_err(|e| { - log_error!(logger, "Failed to setup VssStore: {}", e); - BuildError::KVStoreSetupFailed - })?; + VssStore::new(vss_url, store_id, vss_seed_bytes, header_provider, Arc::clone(&runtime)); build_with_store_internal( config, self.chain_data_source_config.as_ref(), diff --git a/src/io/vss_store.rs b/src/io/vss_store.rs index 296eaabe3..e2cfc3c7b 100644 --- a/src/io/vss_store.rs +++ b/src/io/vss_store.rs @@ -6,6 +6,8 @@ // accordance with one or both of these licenses. use crate::io::utils::check_namespace_key_validity; +use crate::runtime::Runtime; + use bitcoin::hashes::{sha256, Hash, HashEngine, Hmac, HmacEngine}; use lightning::io::{self, Error, ErrorKind}; use lightning::util::persist::KVStore; @@ -15,7 +17,6 @@ use rand::RngCore; use std::panic::RefUnwindSafe; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; use vss_client::client::VssClient; use vss_client::error::VssError; use vss_client::headers::VssHeaderProvider; @@ -41,7 +42,7 @@ type CustomRetryPolicy = FilteredRetryPolicy< pub struct VssStore { client: VssClient, store_id: String, - runtime: Runtime, + runtime: Arc, storable_builder: StorableBuilder, key_obfuscator: KeyObfuscator, } @@ -49,9 +50,8 @@ pub struct VssStore { impl VssStore { pub(crate) fn new( base_url: String, store_id: String, vss_seed: [u8; 32], - header_provider: Arc, - ) -> io::Result { - let runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + header_provider: Arc, runtime: Arc, + ) -> Self { let (data_encryption_key, obfuscation_master_key) = derive_data_encryption_and_obfuscation_keys(&vss_seed); let key_obfuscator = KeyObfuscator::new(obfuscation_master_key); @@ -70,7 +70,7 @@ impl VssStore { }) as _); let client = VssClient::new_with_headers(base_url, retry_policy, header_provider); - Ok(Self { client, store_id, runtime, storable_builder, key_obfuscator }) + Self { client, store_id, runtime, storable_builder, key_obfuscator } } fn build_key( @@ -136,19 +136,16 @@ impl KVStore for VssStore { store_id: self.store_id.clone(), key: self.build_key(primary_namespace, secondary_namespace, key)?, }; - - let resp = - tokio::task::block_in_place(|| self.runtime.block_on(self.client.get_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to read from key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - match e { - VssError::NoSuchKeyError(..) => Error::new(ErrorKind::NotFound, msg), - _ => Error::new(ErrorKind::Other, msg), - } - })?; + let resp = self.runtime.block_on(self.client.get_object(&request)).map_err(|e| { + let msg = format!( + "Failed to read from key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + match e { + VssError::NoSuchKeyError(..) => Error::new(ErrorKind::NotFound, msg), + _ => Error::new(ErrorKind::Other, msg), + } + })?; // unwrap safety: resp.value must be always present for a non-erroneous VSS response, otherwise // it is an API-violation which is converted to [`VssError::InternalServerError`] in [`VssClient`] let storable = Storable::decode(&resp.value.unwrap().value[..]).map_err(|e| { @@ -179,14 +176,13 @@ impl KVStore for VssStore { delete_items: vec![], }; - tokio::task::block_in_place(|| self.runtime.block_on(self.client.put_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to write to key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - Error::new(ErrorKind::Other, msg) - })?; + self.runtime.block_on(self.client.put_object(&request)).map_err(|e| { + let msg = format!( + "Failed to write to key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + Error::new(ErrorKind::Other, msg) + })?; Ok(()) } @@ -204,30 +200,29 @@ impl KVStore for VssStore { }), }; - tokio::task::block_in_place(|| self.runtime.block_on(self.client.delete_object(&request))) - .map_err(|e| { - let msg = format!( - "Failed to delete key {}/{}/{}: {}", - primary_namespace, secondary_namespace, key, e - ); - Error::new(ErrorKind::Other, msg) - })?; + self.runtime.block_on(self.client.delete_object(&request)).map_err(|e| { + let msg = format!( + "Failed to delete key {}/{}/{}: {}", + primary_namespace, secondary_namespace, key, e + ); + Error::new(ErrorKind::Other, msg) + })?; Ok(()) } fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?; - let keys = tokio::task::block_in_place(|| { - self.runtime.block_on(self.list_all_keys(primary_namespace, secondary_namespace)) - }) - .map_err(|e| { - let msg = format!( - "Failed to retrieve keys in namespace: {}/{} : {}", - primary_namespace, secondary_namespace, e - ); - Error::new(ErrorKind::Other, msg) - })?; + let keys = self + .runtime + .block_on(self.list_all_keys(primary_namespace, secondary_namespace)) + .map_err(|e| { + let msg = format!( + "Failed to retrieve keys in namespace: {}/{} : {}", + primary_namespace, secondary_namespace, e + ); + Error::new(ErrorKind::Other, msg) + })?; Ok(keys) } @@ -266,10 +261,27 @@ mod tests { use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng, RngCore}; use std::collections::HashMap; + use tokio::runtime; use vss_client::headers::FixedHeaders; #[test] - fn read_write_remove_list_persist() { + fn vss_read_write_remove_list_persist() { + let runtime = Arc::new(Runtime::new().unwrap()); + let vss_base_url = std::env::var("TEST_VSS_BASE_URL").unwrap(); + let mut rng = thread_rng(); + let rand_store_id: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); + let mut vss_seed = [0u8; 32]; + rng.fill_bytes(&mut vss_seed); + let header_provider = Arc::new(FixedHeaders::new(HashMap::new())); + let vss_store = + VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider, runtime).unwrap(); + + do_read_write_remove_list_persist(&vss_store); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn vss_read_write_remove_list_persist_in_runtime_context() { + let runtime = Arc::new(Runtime::new().unwrap()); let vss_base_url = std::env::var("TEST_VSS_BASE_URL").unwrap(); let mut rng = thread_rng(); let rand_store_id: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); @@ -277,8 +289,9 @@ mod tests { rng.fill_bytes(&mut vss_seed); let header_provider = Arc::new(FixedHeaders::new(HashMap::new())); let vss_store = - VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider).unwrap(); + VssStore::new(vss_base_url, rand_store_id, vss_seed, header_provider, runtime).unwrap(); do_read_write_remove_list_persist(&vss_store); + drop(vss_store) } }