Skip to content

Commit

Permalink
Merge pull request #2976 from fibonacci1729/fix-2974
Browse files Browse the repository at this point in the history
fix #2974 -- use redis::aio::ConnectionManager
  • Loading branch information
fibonacci1729 authored Jan 10, 2025
2 parents 26ff330 + ec11ba2 commit e8edf4d
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 64 deletions.
38 changes: 36 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 13 additions & 6 deletions crates/factor-key-value/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub trait StoreManager: Sync + Send {

#[async_trait]
pub trait Store: Sync + Send {
async fn after_open(&self) -> Result<(), Error> {
Ok(())
}
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error>;
async fn delete(&self, key: &str) -> Result<(), Error>;
Expand Down Expand Up @@ -109,11 +112,13 @@ impl key_value::HostStore for KeyValueDispatch {
async fn open(&mut self, name: String) -> Result<Result<Resource<key_value::Store>, Error>> {
Ok(async {
if self.allowed_stores.contains(&name) {
let store = self
let store = self.manager.get(&name).await?;
store.after_open().await?;
let store_idx = self
.stores
.push(self.manager.get(&name).await?)
.push(store)
.map_err(|()| Error::StoreTableFull)?;
Ok(Resource::new_own(store))
Ok(Resource::new_own(store_idx))
} else {
Err(Error::AccessDenied)
}
Expand Down Expand Up @@ -193,11 +198,13 @@ impl wasi_keyvalue::store::Host for KeyValueDispatch {
identifier: String,
) -> Result<Resource<wasi_keyvalue::store::Bucket>, wasi_keyvalue::store::Error> {
if self.allowed_stores.contains(&identifier) {
let store = self
let store = self.manager.get(&identifier).await.map_err(to_wasi_err)?;
store.after_open().await.map_err(to_wasi_err)?;
let store_idx = self
.stores
.push(self.manager.get(&identifier).await.map_err(to_wasi_err)?)
.push(store)
.map_err(|()| wasi_keyvalue::store::Error::Other("store table full".to_string()))?;
Ok(Resource::new_own(store))
Ok(Resource::new_own(store_idx))
} else {
Err(wasi_keyvalue::store::Error::AccessDenied)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/factor-key-value/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ struct CachingStore {

#[async_trait]
impl Store for CachingStore {
async fn after_open(&self) -> Result<(), Error> {
self.inner.after_open().await
}

async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
// Retrieve the specified value from the cache, lazily populating the cache as necessary.

Expand Down
2 changes: 1 addition & 1 deletion crates/key-value-redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = { workspace = true }

[dependencies]
anyhow = { workspace = true }
redis = { version = "0.27", features = ["tokio-comp", "tokio-native-tls-comp"] }
redis = { version = "0.28", features = ["tokio-comp", "tokio-native-tls-comp", "connection-manager"] }
serde = { workspace = true }
spin-core = { path = "../core" }
spin-factor-key-value = { path = "../factor-key-value" }
Expand Down
83 changes: 28 additions & 55 deletions crates/key-value-redis/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use anyhow::{Context, Result};
use redis::{aio::MultiplexedConnection, parse_redis_url, AsyncCommands, Client, RedisError};
use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
use spin_core::async_trait;
use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::sync::{Mutex, OnceCell};
use tokio::sync::OnceCell;
use url::Url;

pub struct KeyValueRedis {
database_url: Url,
connection: OnceCell<Arc<Mutex<MultiplexedConnection>>>,
connection: OnceCell<ConnectionManager>,
}

impl KeyValueRedis {
Expand All @@ -30,10 +29,8 @@ impl StoreManager for KeyValueRedis {
.connection
.get_or_try_init(|| async {
Client::open(self.database_url.clone())?
.get_multiplexed_async_connection()
.get_connection_manager()
.await
.map(Mutex::new)
.map(Arc::new)
})
.await
.map_err(log_error)?;
Expand All @@ -55,90 +52,69 @@ impl StoreManager for KeyValueRedis {
}

struct RedisStore {
connection: Arc<Mutex<MultiplexedConnection>>,
connection: ConnectionManager,
database_url: Url,
}

struct CompareAndSwap {
key: String,
connection: Arc<Mutex<MultiplexedConnection>>,
connection: ConnectionManager,
bucket_rep: u32,
}

#[async_trait]
impl Store for RedisStore {
async fn after_open(&self) -> Result<(), Error> {
if let Err(_error) = self.connection.clone().ping::<()>().await {
// If an IO error happens, ConnectionManager will start reconnection in the background
// so we do not take any action and just pray re-connection will be successful.
}
Ok(())
}

async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
let mut conn = self.connection.lock().await;
conn.get(key).await.map_err(log_error)
self.connection.clone().get(key).await.map_err(log_error)
}

async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
self.connection
.lock()
.await
.clone()
.set(key, value)
.await
.map_err(log_error)
}

async fn delete(&self, key: &str) -> Result<(), Error> {
self.connection
.lock()
.await
.del(key)
.await
.map_err(log_error)
self.connection.clone().del(key).await.map_err(log_error)
}

async fn exists(&self, key: &str) -> Result<bool, Error> {
self.connection
.lock()
.await
.exists(key)
.await
.map_err(log_error)
self.connection.clone().exists(key).await.map_err(log_error)
}

async fn get_keys(&self) -> Result<Vec<String>, Error> {
self.connection
.lock()
.await
.keys("*")
.await
.map_err(log_error)
self.connection.clone().keys("*").await.map_err(log_error)
}

async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
self.connection
.lock()
.await
.keys(keys)
.await
.map_err(log_error)
self.connection.clone().keys(keys).await.map_err(log_error)
}

async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
self.connection
.lock()
.await
.clone()
.mset(&key_values)
.await
.map_err(log_error)
}

async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
self.connection
.lock()
.await
.del(keys)
.await
.map_err(log_error)
self.connection.clone().del(keys).await.map_err(log_error)
}

async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
self.connection
.lock()
.await
.clone()
.incr(key, delta)
.await
.map_err(log_error)
Expand All @@ -154,10 +130,8 @@ impl Store for RedisStore {
) -> Result<Arc<dyn Cas>, Error> {
let cx = Client::open(self.database_url.clone())
.map_err(log_error)?
.get_multiplexed_async_connection()
.get_connection_manager()
.await
.map(Mutex::new)
.map(Arc::new)
.map_err(log_error)?;

Ok(Arc::new(CompareAndSwap {
Expand All @@ -175,12 +149,11 @@ impl Cas for CompareAndSwap {
async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
redis::cmd("WATCH")
.arg(&self.key)
.exec_async(self.connection.lock().await.deref_mut())
.exec_async(&mut self.connection.clone())
.await
.map_err(log_error)?;
self.connection
.lock()
.await
.clone()
.get(&self.key)
.await
.map_err(log_error)
Expand All @@ -194,12 +167,12 @@ impl Cas for CompareAndSwap {
let res: Result<(), RedisError> = transaction
.atomic()
.set(&self.key, value)
.query_async(self.connection.lock().await.deref_mut())
.query_async(&mut self.connection.clone())
.await;

redis::cmd("UNWATCH")
.arg(&self.key)
.exec_async(self.connection.lock().await.deref_mut())
.exec_async(&mut self.connection.clone())
.await
.map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;

Expand Down

0 comments on commit e8edf4d

Please sign in to comment.