Skip to content

Commit

Permalink
fix: thread-safe extensions
Browse files Browse the repository at this point in the history
Ensure that Extensions are Send by replacing shared references with
thread-safe alternatives.

Signed-off-by: Sergei Trofimov <sergei.trofimov@arm.com>
  • Loading branch information
setrofim committed Oct 31, 2024
1 parent d0e4d2d commit 406a052
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
// SPDX-License-Identifier: Apache-2.0

use std::cell::RefCell;
use std::collections::{BTreeMap, HashSet};
use std::fmt;
use std::rc::Rc;
use std::sync::Mutex;
use std::sync::{Arc, Mutex, RwLock};

use lazy_static::lazy_static;
use serde::de::{Error as _, MapAccess, SeqAccess, Visitor};
Expand Down Expand Up @@ -231,8 +229,8 @@ enum CollectedKey {

#[derive(Debug)]
pub struct Extensions {
by_key: BTreeMap<i32, Rc<RefCell<ExtensionEntry>>>,
by_name: BTreeMap<String, Rc<RefCell<ExtensionEntry>>>,
by_key: BTreeMap<i32, Arc<RwLock<ExtensionEntry>>>,
by_name: BTreeMap<String, Arc<RwLock<ExtensionEntry>>>,
collected: BTreeMap<CollectedKey, ExtensionValue>,
}

Expand Down Expand Up @@ -264,7 +262,7 @@ impl<'de> Extensions {
));
}

let entry = Rc::new(RefCell::new(ExtensionEntry::new(kind)));
let entry = Arc::new(RwLock::new(ExtensionEntry::new(kind)));

// Check whether any of the values we previously collected match the key or name for
// this entry. If so, add the value to the entry, ensuring it is the right kind.
Expand All @@ -280,20 +278,20 @@ impl<'de> Extensions {
.or(self.collected.get(&CollectedKey::Name(name.to_string())));
match collected {
Some(v) => {
let entry_kind = &entry.borrow().kind.clone();
let entry_kind = &entry.read().unwrap().kind.clone();

if v.is(entry_kind) {
entry.borrow_mut().value = v.clone();
entry.write().unwrap().value = v.clone();
Ok(())
} else if v.can_convert(entry_kind) {
entry.borrow_mut().value = v.convert(entry_kind)?;
entry.write().unwrap().value = v.convert(entry_kind)?;
Ok(())
} else {
Err(Error::ExtensionError(
format!(
"kind mismatch: value is {vk:?}, but want {ek:?}",
vk = v.kind(),
ek = entry.borrow().kind
ek = entry.read().unwrap().kind
)
.to_string(),
))
Expand All @@ -302,8 +300,8 @@ impl<'de> Extensions {
None => Ok(()),
}?;

self.by_key.insert(key, Rc::clone(&entry));
self.by_name.insert(name.to_string(), Rc::clone(&entry));
self.by_key.insert(key, Arc::clone(&entry));
self.by_name.insert(name.to_string(), Arc::clone(&entry));

Ok(())
}
Expand All @@ -319,25 +317,25 @@ impl<'de> Extensions {
pub fn get_by_key(&self, key: &i32) -> Option<ExtensionValue> {
self.by_key
.get(key)
.map(|entry| entry.borrow().value.clone())
.map(|entry| entry.read().unwrap().value.clone())
}

pub fn get_by_name(&self, name: &str) -> Option<ExtensionValue> {
self.by_name
.get(name)
.map(|entry| entry.borrow().value.clone())
.map(|entry| entry.read().unwrap().value.clone())
}

pub fn get_kind_by_key(&self, key: &i32) -> ExtensionKind {
match self.by_key.get(key) {
Some(entry) => entry.borrow().kind.clone(),
Some(entry) => entry.read().unwrap().kind.clone(),
None => ExtensionKind::Unset,
}
}

pub fn get_kind_by_name(&self, name: &str) -> ExtensionKind {
match self.by_name.get(name) {
Some(entry) => entry.borrow().kind.clone(),
Some(entry) => entry.read().unwrap().kind.clone(),
None => ExtensionKind::Unset,
}
}
Expand All @@ -347,15 +345,15 @@ impl<'de> Extensions {
format!("{key} not registered").to_string(),
))?;

if !value.is(&entry.borrow().kind) {
if !value.is(&entry.read().unwrap().kind) {
return Err(Error::ExtensionError(format!(
"kind mismatch: value is {vk:?}, but want {ek:?}",
vk = value.kind(),
ek = entry.borrow().kind
ek = entry.read().unwrap().kind
)));
}

entry.borrow_mut().value = value;
entry.write().unwrap().value = value;

Ok(())
}
Expand All @@ -365,15 +363,15 @@ impl<'de> Extensions {
format!("{name} not registered").to_string(),
))?;

if !value.is(&entry.borrow().kind) {
if !value.is(&entry.read().unwrap().kind) {
return Err(Error::ExtensionError(format!(
"kind mismatch: value is {vk:?}, but want {ek:?}",
vk = value.kind(),
ek = entry.borrow().kind
ek = entry.read().unwrap().kind
)));
}

entry.borrow_mut().value = value;
entry.write().unwrap().value = value;

Ok(())
}
Expand Down Expand Up @@ -449,11 +447,11 @@ impl<'de> Extensions {
M: serde::ser::SerializeMap,
{
for (name, val) in &self.by_name {
if val.borrow().value.is(&ExtensionKind::Unset) {
if val.read().unwrap().value.is(&ExtensionKind::Unset) {
continue;
}

map.serialize_entry(&name, &val.borrow().value)?;
map.serialize_entry(&name, &val.read().unwrap().value)?;
}

Ok(())
Expand All @@ -464,11 +462,11 @@ impl<'de> Extensions {
M: serde::ser::SerializeMap,
{
for (key, val) in &self.by_key {
if val.borrow().value.is(&ExtensionKind::Unset) {
if val.read().unwrap().value.is(&ExtensionKind::Unset) {
continue;
}

map.serialize_entry(&key, &val.borrow().value)?;
map.serialize_entry(&key, &val.read().unwrap().value)?;
}

Ok(())
Expand All @@ -480,7 +478,7 @@ impl PartialEq for Extensions {
for (name, val) in &self.by_name {
match other.get_by_name(name) {
Some(other_val) => {
if val.borrow().value != other_val {
if val.read().unwrap().value != other_val {
return false;
}
}
Expand All @@ -491,7 +489,7 @@ impl PartialEq for Extensions {
for (key, val) in &self.by_key {
match other.get_by_key(key) {
Some(other_val) => {
if val.borrow().value != other_val {
if val.read().unwrap().value != other_val {
return false;
}
}
Expand Down Expand Up @@ -662,6 +660,7 @@ mod test {
use crate::error::Error;

use std::str;
use std::thread;

use serde::ser::SerializeMap;
use serde::ser::Serializer;
Expand Down Expand Up @@ -747,4 +746,23 @@ mod test {
panic!("wrong variant: {res:?}");
}
}

#[test]
fn test_send() {
let mut exts = Extensions::new();
exts.register("foo", 1, ExtensionKind::String).unwrap();
exts.set_by_name("foo", ExtensionValue::String("test".to_string()))
.unwrap();

let handle = thread::spawn(move || {
let val = match exts.get_by_name("foo").unwrap() {
ExtensionValue::String(v) => v,
_ => panic!(),
};

assert_eq!(&val, "test");
});

handle.join().unwrap();
}
}

0 comments on commit 406a052

Please sign in to comment.