From a68c84cdc3ab291cbeeb331716225daa713b210e Mon Sep 17 00:00:00 2001
From: WEI Xikai
Date: Fri, 3 Mar 2023 15:37:06 +0800
Subject: [PATCH] refactor: replace tokio lock with std lock in some sync
 scenarios (#694)

* refactor: replace tokio lock with std lock in some sync scenarios

* fix: clippy warnings
---
 common_util/src/partitioned_lock.rs       | 32 +++++-----
 components/object_store/src/mem_cache.rs  | 75 ++++++++++-------------
 remote_engine_client/src/cached_router.rs |  9 ++-
 remote_engine_client/src/channel.rs       |  7 +--
 4 files changed, 54 insertions(+), 69 deletions(-)

diff --git a/common_util/src/partitioned_lock.rs b/common_util/src/partitioned_lock.rs
index 6f95697f88..ae1f15c4a6 100644
--- a/common_util/src/partitioned_lock.rs
+++ b/common_util/src/partitioned_lock.rs
@@ -6,11 +6,9 @@ use std::{
     collections::hash_map::DefaultHasher,
     hash::{Hash, Hasher},
     num::NonZeroUsize,
-    sync::Arc,
+    sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
 };

-use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
-
 /// Simple partitioned `RwLock`
 pub struct PartitionedRwLock<T> {
     partitions: Vec<Arc<RwLock<T>>>,
@@ -28,16 +26,16 @@ impl<T> PartitionedRwLock<T> {
         }
     }

-    pub async fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
+    pub fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
         let rwlock = self.get_partition(key);

-        rwlock.read().await
+        rwlock.read().unwrap()
     }

-    pub async fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
+    pub fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
         let rwlock = self.get_partition(key);

-        rwlock.write().await
+        rwlock.write().unwrap()
     }

     fn get_partition<K: Eq + Hash>(&self, key: &K) -> &RwLock<T> {
@@ -66,10 +64,10 @@ impl<T> PartitionedMutex<T> {
         }
     }

-    pub async fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
+    pub fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
         let mutex = self.get_partition(key);

-        mutex.lock().await
+        mutex.lock().unwrap()
     }

     fn get_partition<K: Eq + Hash>(&self, key: &K) -> &Mutex<T> {
@@ -87,37 +85,37 @@ mod tests {

     use super::*;

-    #[tokio::test]
-    async fn test_partitioned_rwlock() {
+    #[test]
+    fn test_partitioned_rwlock() {
         let test_locked_map =
             PartitionedRwLock::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
         let test_key = "test_key".to_string();
         let test_value = "test_value".to_string();

         {
-            let mut map = test_locked_map.write(&test_key).await;
+            let mut map = test_locked_map.write(&test_key);
             map.insert(test_key.clone(), test_value.clone());
         }

         {
-            let map = test_locked_map.read(&test_key).await;
+            let map = test_locked_map.read(&test_key);
             assert_eq!(map.get(&test_key).unwrap(), &test_value);
         }
     }

-    #[tokio::test]
-    async fn test_partitioned_mutex() {
+    #[test]
+    fn test_partitioned_mutex() {
         let test_locked_map = PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
         let test_key = "test_key".to_string();
         let test_value = "test_value".to_string();

         {
-            let mut map = test_locked_map.lock(&test_key).await;
+            let mut map = test_locked_map.lock(&test_key);
             map.insert(test_key.clone(), test_value.clone());
         }

         {
-            let map = test_locked_map.lock(&test_key).await;
+            let map = test_locked_map.lock(&test_key);
             assert_eq!(map.get(&test_key).unwrap(), &test_value);
         }
     }
diff --git a/components/object_store/src/mem_cache.rs b/components/object_store/src/mem_cache.rs
index 00fec0fd6f..8cc4a329d1 100644
--- a/components/object_store/src/mem_cache.rs
+++ b/components/object_store/src/mem_cache.rs
@@ -10,7 +10,7 @@ use std::{
     hash::{Hash, Hasher},
     num::NonZeroUsize,
     ops::Range,
-    sync::Arc,
+    sync::{Arc, Mutex},
 };

 use async_trait::async_trait;
@@ -18,7 +18,7 @@ use bytes::Bytes;
 use clru::{CLruCache, CLruCacheConfig, WeightScale};
 use futures::stream::BoxStream;
 use snafu::{OptionExt, Snafu};
-use tokio::{io::AsyncWrite, sync::Mutex};
+use tokio::io::AsyncWrite;
 use upstream::{path::Path, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result};

 use crate::ObjectStoreRef;
@@ -52,26 +52,26 @@ impl Partition {
     }
 }

 impl Partition {
-    async fn get(&self, key: &str) -> Option<Bytes> {
-        let mut guard = self.inner.lock().await;
+    fn get(&self, key: &str) -> Option<Bytes> {
+        let mut guard = self.inner.lock().unwrap();
         guard.get(key).cloned()
     }

-    async fn peek(&self, key: &str) -> Option<Bytes> {
+    fn peek(&self, key: &str) -> Option<Bytes> {
         // FIXME: actually, here write lock is not necessary.
-        let guard = self.inner.lock().await;
+        let guard = self.inner.lock().unwrap();
         guard.peek(key).cloned()
     }

-    async fn insert(&self, key: String, value: Bytes) {
-        let mut guard = self.inner.lock().await;
+    fn insert(&self, key: String, value: Bytes) {
+        let mut guard = self.inner.lock().unwrap();
         // don't care error now.
         _ = guard.put_with_weight(key, value);
     }

     #[cfg(test)]
-    async fn keys(&self) -> Vec<String> {
-        let guard = self.inner.lock().await;
+    fn keys(&self) -> Vec<String> {
+        let guard = self.inner.lock().unwrap();
         guard
             .iter()
             .map(|(key, _)| key)
@@ -115,34 +115,31 @@ impl MemCache {
         self.partitions[hasher.finish() as usize & self.partition_mask].clone()
     }

-    async fn get(&self, key: &str) -> Option<Bytes> {
+    fn get(&self, key: &str) -> Option<Bytes> {
         let partition = self.locate_partition(key);
-        partition.get(key).await
+        partition.get(key)
     }

-    async fn peek(&self, key: &str) -> Option<Bytes> {
+    fn peek(&self, key: &str) -> Option<Bytes> {
         let partition = self.locate_partition(key);
-        partition.peek(key).await
+        partition.peek(key)
     }

-    async fn insert(&self, key: String, value: Bytes) {
+    fn insert(&self, key: String, value: Bytes) {
         let partition = self.locate_partition(&key);
-        partition.insert(key, value).await;
+        partition.insert(key, value);
     }

+    /// Give a description of the cache state.
     #[cfg(test)]
-    async fn to_string(&self) -> String {
-        futures::future::join_all(
-            self.partitions
-                .iter()
-                .map(|part| async { part.keys().await.join(",") }),
-        )
-        .await
-        .into_iter()
-        .enumerate()
-        .map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
-        .collect::<Vec<_>>()
-        .join("\n")
+    fn state_desc(&self) -> String {
+        self.partitions
+            .iter()
+            .map(|part| part.keys().join(","))
+            .enumerate()
+            .map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
+            .collect::<Vec<_>>()
+            .join("\n")
     }
 }

@@ -195,21 +192,21 @@ impl MemCacheStore {
         // TODO(chenxiang): What if there are some overlapping range in cache?
         // A request with range [5, 10) can also use [0, 20) cache
         let cache_key = Self::cache_key(location, &range);
-        if let Some(bytes) = self.cache.get(&cache_key).await {
+        if let Some(bytes) = self.cache.get(&cache_key) {
             return Ok(bytes);
         }

         // TODO(chenxiang): What if two threads reach here? It's better to
         // pend one thread, and only let one to fetch data from underlying store.
         let bytes = self.underlying_store.get_range(location, range).await?;
-        self.cache.insert(cache_key, bytes.clone()).await;
+        self.cache.insert(cache_key, bytes.clone());

         Ok(bytes)
     }

     async fn get_range_with_ro_cache(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
         let cache_key = Self::cache_key(location, &range);
-        if let Some(bytes) = self.cache.peek(&cache_key).await {
+        if let Some(bytes) = self.cache.peek(&cache_key) {
             return Ok(bytes);
         }

@@ -297,7 +294,7 @@ mod test {

     use super::*;

-    async fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
+    fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
         let local_path = tempdir().unwrap();
         let local_store = Arc::new(LocalFileSystem::new_with_prefix(local_path.path()).unwrap());

@@ -309,7 +306,7 @@ mod test {
     #[tokio::test]
     async fn test_mem_cache_evict() {
         // single partition
-        let store = prepare_store(0, 13).await;
+        let store = prepare_store(0, 13);

         // write date
         let location = Path::from("1.sst");
@@ -324,7 +321,6 @@ mod test {
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range0_5))
-            .await
             .is_some());

         // get bytes from [5, 10), insert to cache
@@ -333,12 +329,10 @@ mod test {
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range0_5))
-            .await
             .is_some());
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range5_10))
-            .await
             .is_some());

         // get bytes from [10, 15), insert to cache
@@ -351,24 +345,21 @@ mod test {
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range0_5))
-            .await
             .is_none());
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range5_10))
-            .await
             .is_some());
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range10_15))
-            .await
             .is_some());
     }

     #[tokio::test]
     async fn test_mem_cache_partition() {
         // 4 partitions
-        let store = prepare_store(2, 100).await;
+        let store = prepare_store(2, 100);

         let location = Path::from("partition.sst");
         store
             .put(&location, Bytes::from_static(&[1; 1024]))
             .await
             .unwrap();
@@ -388,18 +379,16 @@ mod test {
 1: [partition.sst-100-105]
 2: []
 3: [partition.sst-0-5]"#,
-            store.cache.as_ref().to_string().await
+            store.cache.as_ref().state_desc()
         );

         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range0_5))
-            .await
             .is_some());
         assert!(store
             .cache
             .get(&MemCacheStore::cache_key(&location, &range100_105))
-            .await
             .is_some());
     }
 }
diff --git a/remote_engine_client/src/cached_router.rs b/remote_engine_client/src/cached_router.rs
index bfb6cab176..860b252d32 100644
--- a/remote_engine_client/src/cached_router.rs
+++ b/remote_engine_client/src/cached_router.rs
@@ -2,14 +2,13 @@

 //! Cached router

-use std::collections::HashMap;
+use std::{collections::HashMap, sync::RwLock};

 use ceresdbproto::storage::{self, RequestContext};
 use log::debug;
 use router::RouterRef;
 use snafu::{OptionExt, ResultExt};
 use table_engine::remote::model::TableIdentifier;
-use tokio::sync::RwLock;
 use tonic::transport::Channel;

 use crate::{channel::ChannelPool, config::Config, error::*};
@@ -40,7 +39,7 @@ impl CachedRouter {
     pub async fn route(&self, table_ident: &TableIdentifier) -> Result<Channel> {
         // Find in cache first.
         let channel_opt = {
-            let cache = self.cache.read().await;
+            let cache = self.cache.read().unwrap();
             cache.get(table_ident).cloned()
         };

@@ -62,7 +61,7 @@ impl CachedRouter {
             let channel = self.do_route(table_ident).await?;

             {
-                let mut cache = self.cache.write().await;
+                let mut cache = self.cache.write().unwrap();
                 // Double check here, if still not found, we put it.
                 let channel_opt = cache.get(table_ident).cloned();
                 if channel_opt.is_none() {
@@ -81,7 +80,7 @@ impl CachedRouter {
     }

     pub async fn evict(&self, table_ident: &TableIdentifier) {
-        let mut cache = self.cache.write().await;
+        let mut cache = self.cache.write().unwrap();

         let _ = cache.remove(table_ident);
     }
diff --git a/remote_engine_client/src/channel.rs b/remote_engine_client/src/channel.rs
index 5776986d4b..32e3b1054f 100644
--- a/remote_engine_client/src/channel.rs
+++ b/remote_engine_client/src/channel.rs
@@ -2,11 +2,10 @@

 //! Channel pool

-use std::collections::HashMap;
+use std::{collections::HashMap, sync::RwLock};

 use router::endpoint::Endpoint;
 use snafu::ResultExt;
-use tokio::sync::RwLock;
 use tonic::transport::{Channel, Endpoint as TonicEndpoint};

 use crate::{config::Config, error::*};
@@ -30,7 +29,7 @@ impl ChannelPool {

     pub async fn get(&self, endpoint: &Endpoint) -> Result<Channel> {
         {
-            let inner = self.channels.read().await;
+            let inner = self.channels.read().unwrap();
             if let Some(channel) = inner.get(endpoint) {
                 return Ok(channel.clone());
             }
@@ -40,7 +39,7 @@ impl ChannelPool {
             .builder
             .build(endpoint.clone().to_string().as_str())
             .await?;
-        let mut inner = self.channels.write().await;
+        let mut inner = self.channels.write().unwrap();
         // Double check here.
         if let Some(channel) = inner.get(endpoint) {
             return Ok(channel.clone());
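
Why the swap from `tokio::sync` to `std::sync` is safe in these call sites: every guard taken above lives in a small scope and is dropped before the enclosing future reaches an `.await`, so a task never blocks the executor while holding a lock and the guards never have to be `Send` across an await point. The sketch below is not part of the patch; it illustrates that usage pattern with a partitioned mutex. The names mirror `PartitionedMutex` from the patch, but the modulo sharding over a plain `usize` partition count is a simplifying assumption rather than the crate's actual implementation, which takes a `NonZeroUsize`.

```rust
use std::{
    collections::{hash_map::DefaultHasher, HashMap},
    hash::{Hash, Hasher},
    sync::{Mutex, MutexGuard},
};

/// Simplified partitioned mutex: hash the key, pick one shard, lock only
/// that shard. The API shape mirrors `PartitionedMutex` from the patch, but
/// the sharding here (modulo over a plain `usize`) is an illustrative
/// assumption, not the crate's exact code.
struct PartitionedMutex<T> {
    partitions: Vec<Mutex<T>>,
}

impl<T: Clone> PartitionedMutex<T> {
    fn new(init: T, partition_num: usize) -> Self {
        assert!(partition_num > 0, "need at least one partition");
        Self {
            partitions: (0..partition_num)
                .map(|_| Mutex::new(init.clone()))
                .collect(),
        }
    }

    /// With `std::sync::Mutex` this is a plain synchronous call; `unwrap`
    /// turns lock poisoning into a panic, matching the patch's `.unwrap()`s.
    fn lock<K: Hash>(&self, key: &K) -> MutexGuard<'_, T> {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let idx = (hasher.finish() as usize) % self.partitions.len();
        self.partitions[idx].lock().unwrap()
    }
}

fn main() {
    let map = PartitionedMutex::new(HashMap::<String, String>::new(), 10);

    {
        // The guard lives only inside this block and is never held across
        // an `.await`, which is the invariant that makes std locks safe in
        // the async code paths touched by the patch.
        let mut shard = map.lock(&"test_key");
        shard.insert("test_key".to_string(), "test_value".to_string());
    }

    let shard = map.lock(&"test_key");
    assert_eq!(
        shard.get("test_key").map(String::as_str),
        Some("test_value")
    );
}
```

The same invariant holds in `CachedRouter::route`/`evict` and `ChannelPool::get`: the `std::sync::RwLock` guard is dropped before any `.await` (routing or channel building), and after reacquiring the write lock the map is checked again, per the "Double check here" comments, so an entry inserted by a concurrent task is reused rather than overwritten.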