Skip to content

Commit

Permalink
refactor: replace tokio lock with std lock in some sync scenarios (#694)
Browse files Browse the repository at this point in the history
* refactor: replace tokio lock with std lock in some sync scenarios

* fix: clippy warnings
  • Loading branch information
ShiKaiWi authored Mar 3, 2023
1 parent 28402e4 commit a68c84c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 69 deletions.
32 changes: 15 additions & 17 deletions common_util/src/partitioned_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
num::NonZeroUsize,
sync::Arc,
sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
};

use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

/// Simple partitioned `RwLock`
pub struct PartitionedRwLock<T> {
partitions: Vec<Arc<RwLock<T>>>,
Expand All @@ -28,16 +26,16 @@ impl<T> PartitionedRwLock<T> {
}
}

pub async fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
pub fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
let rwlock = self.get_partition(key);

rwlock.read().await
rwlock.read().unwrap()
}

pub async fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
pub fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
let rwlock = self.get_partition(key);

rwlock.write().await
rwlock.write().unwrap()
}

fn get_partition<K: Eq + Hash>(&self, key: &K) -> &RwLock<T> {
Expand Down Expand Up @@ -66,10 +64,10 @@ impl<T> PartitionedMutex<T> {
}
}

pub async fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
pub fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
let mutex = self.get_partition(key);

mutex.lock().await
mutex.lock().unwrap()
}

fn get_partition<K: Eq + Hash>(&self, key: &K) -> &Mutex<T> {
Expand All @@ -87,37 +85,37 @@ mod tests {

use super::*;

#[tokio::test]
async fn test_partitioned_rwlock() {
#[test]
fn test_partitioned_rwlock() {
let test_locked_map =
PartitionedRwLock::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

{
let mut map = test_locked_map.write(&test_key).await;
let mut map = test_locked_map.write(&test_key);
map.insert(test_key.clone(), test_value.clone());
}

{
let map = test_locked_map.read(&test_key).await;
let map = test_locked_map.read(&test_key);
assert_eq!(map.get(&test_key).unwrap(), &test_value);
}
}

#[tokio::test]
async fn test_partitioned_mutex() {
#[test]
fn test_partitioned_mutex() {
let test_locked_map = PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

{
let mut map = test_locked_map.lock(&test_key).await;
let mut map = test_locked_map.lock(&test_key);
map.insert(test_key.clone(), test_value.clone());
}

{
let map = test_locked_map.lock(&test_key).await;
let map = test_locked_map.lock(&test_key);
assert_eq!(map.get(&test_key).unwrap(), &test_value);
}
}
Expand Down
75 changes: 32 additions & 43 deletions components/object_store/src/mem_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ use std::{
hash::{Hash, Hasher},
num::NonZeroUsize,
ops::Range,
sync::Arc,
sync::{Arc, Mutex},
};

use async_trait::async_trait;
use bytes::Bytes;
use clru::{CLruCache, CLruCacheConfig, WeightScale};
use futures::stream::BoxStream;
use snafu::{OptionExt, Snafu};
use tokio::{io::AsyncWrite, sync::Mutex};
use tokio::io::AsyncWrite;
use upstream::{path::Path, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result};

use crate::ObjectStoreRef;
Expand Down Expand Up @@ -52,26 +52,26 @@ impl Partition {
}

impl Partition {
async fn get(&self, key: &str) -> Option<Bytes> {
let mut guard = self.inner.lock().await;
fn get(&self, key: &str) -> Option<Bytes> {
let mut guard = self.inner.lock().unwrap();
guard.get(key).cloned()
}

async fn peek(&self, key: &str) -> Option<Bytes> {
fn peek(&self, key: &str) -> Option<Bytes> {
// FIXME: actually, here write lock is not necessary.
let guard = self.inner.lock().await;
let guard = self.inner.lock().unwrap();
guard.peek(key).cloned()
}

async fn insert(&self, key: String, value: Bytes) {
let mut guard = self.inner.lock().await;
fn insert(&self, key: String, value: Bytes) {
let mut guard = self.inner.lock().unwrap();
// don't care error now.
_ = guard.put_with_weight(key, value);
}

#[cfg(test)]
async fn keys(&self) -> Vec<String> {
let guard = self.inner.lock().await;
fn keys(&self) -> Vec<String> {
let guard = self.inner.lock().unwrap();
guard
.iter()
.map(|(key, _)| key)
Expand Down Expand Up @@ -115,34 +115,31 @@ impl MemCache {
self.partitions[hasher.finish() as usize & self.partition_mask].clone()
}

async fn get(&self, key: &str) -> Option<Bytes> {
fn get(&self, key: &str) -> Option<Bytes> {
let partition = self.locate_partition(key);
partition.get(key).await
partition.get(key)
}

async fn peek(&self, key: &str) -> Option<Bytes> {
fn peek(&self, key: &str) -> Option<Bytes> {
let partition = self.locate_partition(key);
partition.peek(key).await
partition.peek(key)
}

async fn insert(&self, key: String, value: Bytes) {
fn insert(&self, key: String, value: Bytes) {
let partition = self.locate_partition(&key);
partition.insert(key, value).await;
partition.insert(key, value);
}

/// Give a description of the cache state.
#[cfg(test)]
async fn to_string(&self) -> String {
futures::future::join_all(
self.partitions
.iter()
.map(|part| async { part.keys().await.join(",") }),
)
.await
.into_iter()
.enumerate()
.map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
.collect::<Vec<_>>()
.join("\n")
fn state_desc(&self) -> String {
self.partitions
.iter()
.map(|part| part.keys().join(","))
.enumerate()
.map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
.collect::<Vec<_>>()
.join("\n")
}
}

Expand Down Expand Up @@ -195,21 +192,21 @@ impl MemCacheStore {
// TODO(chenxiang): What if there are some overlapping range in cache?
// A request with range [5, 10) can also use [0, 20) cache
let cache_key = Self::cache_key(location, &range);
if let Some(bytes) = self.cache.get(&cache_key).await {
if let Some(bytes) = self.cache.get(&cache_key) {
return Ok(bytes);
}

// TODO(chenxiang): What if two threads reach here? It's better to
// pend one thread, and only let one to fetch data from underlying store.
let bytes = self.underlying_store.get_range(location, range).await?;
self.cache.insert(cache_key, bytes.clone()).await;
self.cache.insert(cache_key, bytes.clone());

Ok(bytes)
}

async fn get_range_with_ro_cache(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
let cache_key = Self::cache_key(location, &range);
if let Some(bytes) = self.cache.peek(&cache_key).await {
if let Some(bytes) = self.cache.peek(&cache_key) {
return Ok(bytes);
}

Expand Down Expand Up @@ -297,7 +294,7 @@ mod test {

use super::*;

async fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
let local_path = tempdir().unwrap();
let local_store = Arc::new(LocalFileSystem::new_with_prefix(local_path.path()).unwrap());

Expand All @@ -309,7 +306,7 @@ mod test {
#[tokio::test]
async fn test_mem_cache_evict() {
// single partition
let store = prepare_store(0, 13).await;
let store = prepare_store(0, 13);

// write date
let location = Path::from("1.sst");
Expand All @@ -324,7 +321,6 @@ mod test {
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_some());

// get bytes from [5, 10), insert to cache
Expand All @@ -333,12 +329,10 @@ mod test {
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range5_10))
.await
.is_some());

// get bytes from [10, 15), insert to cache
Expand All @@ -351,24 +345,21 @@ mod test {
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_none());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range5_10))
.await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range10_15))
.await
.is_some());
}

#[tokio::test]
async fn test_mem_cache_partition() {
// 4 partitions
let store = prepare_store(2, 100).await;
let store = prepare_store(2, 100);
let location = Path::from("partition.sst");
store
.put(&location, Bytes::from_static(&[1; 1024]))
Expand All @@ -388,18 +379,16 @@ mod test {
1: [partition.sst-100-105]
2: []
3: [partition.sst-0-5]"#,
store.cache.as_ref().to_string().await
store.cache.as_ref().state_desc()
);

assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
.await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range100_105))
.await
.is_some());
}
}
9 changes: 4 additions & 5 deletions remote_engine_client/src/cached_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

//! Cached router
use std::collections::HashMap;
use std::{collections::HashMap, sync::RwLock};

use ceresdbproto::storage::{self, RequestContext};
use log::debug;
use router::RouterRef;
use snafu::{OptionExt, ResultExt};
use table_engine::remote::model::TableIdentifier;
use tokio::sync::RwLock;
use tonic::transport::Channel;

use crate::{channel::ChannelPool, config::Config, error::*};
Expand Down Expand Up @@ -40,7 +39,7 @@ impl CachedRouter {
pub async fn route(&self, table_ident: &TableIdentifier) -> Result<Channel> {
// Find in cache first.
let channel_opt = {
let cache = self.cache.read().await;
let cache = self.cache.read().unwrap();
cache.get(table_ident).cloned()
};

Expand All @@ -62,7 +61,7 @@ impl CachedRouter {
let channel = self.do_route(table_ident).await?;

{
let mut cache = self.cache.write().await;
let mut cache = self.cache.write().unwrap();
// Double check here, if still not found, we put it.
let channel_opt = cache.get(table_ident).cloned();
if channel_opt.is_none() {
Expand All @@ -81,7 +80,7 @@ impl CachedRouter {
}

pub async fn evict(&self, table_ident: &TableIdentifier) {
let mut cache = self.cache.write().await;
let mut cache = self.cache.write().unwrap();
let _ = cache.remove(table_ident);
}

Expand Down
7 changes: 3 additions & 4 deletions remote_engine_client/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

//! Channel pool
use std::collections::HashMap;
use std::{collections::HashMap, sync::RwLock};

use router::endpoint::Endpoint;
use snafu::ResultExt;
use tokio::sync::RwLock;
use tonic::transport::{Channel, Endpoint as TonicEndpoint};

use crate::{config::Config, error::*};
Expand All @@ -30,7 +29,7 @@ impl ChannelPool {

pub async fn get(&self, endpoint: &Endpoint) -> Result<Channel> {
{
let inner = self.channels.read().await;
let inner = self.channels.read().unwrap();
if let Some(channel) = inner.get(endpoint) {
return Ok(channel.clone());
}
Expand All @@ -40,7 +39,7 @@ impl ChannelPool {
.builder
.build(endpoint.clone().to_string().as_str())
.await?;
let mut inner = self.channels.write().await;
let mut inner = self.channels.write().unwrap();
// Double check here.
if let Some(channel) = inner.get(endpoint) {
return Ok(channel.clone());
Expand Down

0 comments on commit a68c84c

Please sign in to comment.