Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assign memo ingredients per salsa-struct-ingredient #614

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions components/salsa-macro-rules/src/setup_input_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ macro_rules! setup_input_struct {
}

impl $zalsa::SalsaStructInDb for $Struct {
fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> {
aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default())
}
}

impl $Struct {
Expand Down
3 changes: 3 additions & 0 deletions components/salsa-macro-rules/src/setup_interned_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ macro_rules! setup_interned_struct {
}

impl $zalsa::SalsaStructInDb for $Struct<'_> {
fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> {
aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default())
}
}

unsafe impl $zalsa::Update for $Struct<'_> {
Expand Down
19 changes: 19 additions & 0 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ macro_rules! setup_tracked_fn {
$zalsa::IngredientCache::new();

impl $zalsa::SalsaStructInDb for $InternedData<'_> {
fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> {
None
}
}

impl $zalsa::interned::Configuration for $Configuration {
Expand Down Expand Up @@ -199,7 +202,19 @@ macro_rules! setup_tracked_fn {
aux: &dyn $zalsa::JarAux,
first_index: $zalsa::IngredientIndex,
) -> Vec<Box<dyn $zalsa::Ingredient>> {
let struct_index = $zalsa::macro_if! {
if $needs_interner {
first_index.successor(0)
} else {
<$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux)
.expect(
"Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!"
)
}
};

let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new(
struct_index,
first_index,
aux,
);
Expand All @@ -219,6 +234,10 @@ macro_rules! setup_tracked_fn {
}
}
}

fn salsa_struct_type_id(&self) -> Option<core::any::TypeId> {
None
}
}

#[allow(non_local_definitions)]
Expand Down
3 changes: 3 additions & 0 deletions components/salsa-macro-rules/src/setup_tracked_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ macro_rules! setup_tracked_struct {
}

impl $zalsa::SalsaStructInDb for $Struct<'_> {
fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> {
aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default())
}
}

impl $zalsa::TrackedStructInDb for $Struct<'_> {
Expand Down
4 changes: 4 additions & 0 deletions src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl<A: Accumulator> Jar for JarImpl<A> {
) -> Vec<Box<dyn Ingredient>> {
vec![Box::new(<IngredientImpl<A>>::new(first_index))]
}

fn salsa_struct_type_id(&self) -> Option<std::any::TypeId> {
None
}
}

pub struct IngredientImpl<A: Accumulator> {
Expand Down
4 changes: 2 additions & 2 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ impl<C> IngredientImpl<C>
where
C: Configuration,
{
pub fn new(index: IngredientIndex, aux: &dyn JarAux) -> Self {
pub fn new(struct_index: IngredientIndex, index: IngredientIndex, aux: &dyn JarAux) -> Self {
Self {
index,
memo_ingredient_index: aux.next_memo_ingredient_index(index),
memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index),
lru: Default::default(),
deleted_entries: Default::default(),
}
Expand Down
11 changes: 10 additions & 1 deletion src/ingredient.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,19 @@ pub trait Jar: Any {
aux: &dyn JarAux,
first_index: IngredientIndex,
) -> Vec<Box<dyn Ingredient>>;

/// If this jar's first ingredient is a salsa struct, return its `TypeId`
fn salsa_struct_type_id(&self) -> Option<TypeId>;
}

pub trait JarAux {
fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex;
fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex>;

fn next_memo_ingredient_index(
&self,
struct_ingredient_index: IngredientIndex,
ingredient_index: IngredientIndex,
) -> MemoIngredientIndex;
}

pub trait Ingredient: Any + std::fmt::Debug + Send + Sync {
Expand Down
10 changes: 9 additions & 1 deletion src/input.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{any::Any, fmt, ops::DerefMut};
use std::{
any::{Any, TypeId},
fmt,
ops::DerefMut,
};

pub mod input_field;
pub mod setter;
Expand Down Expand Up @@ -60,6 +64,10 @@ impl<C: Configuration> Jar for JarImpl<C> {
}))
.collect()
}

fn salsa_struct_type_id(&self) -> Option<std::any::TypeId> {
Some(TypeId::of::<<C as Configuration>::Struct>())
}
}

pub struct IngredientImpl<C: Configuration> {
Expand Down
5 changes: 5 additions & 0 deletions src/interned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::table::Slot;
use crate::zalsa::IngredientIndex;
use crate::zalsa_local::QueryOrigin;
use crate::{Database, DatabaseKeyIndex, Id};
use std::any::TypeId;
use std::fmt;
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
Expand Down Expand Up @@ -92,6 +93,10 @@ impl<C: Configuration> Jar for JarImpl<C> {
) -> Vec<Box<dyn Ingredient>> {
vec![Box::new(IngredientImpl::<C>::new(first_index)) as _]
}

fn salsa_struct_type_id(&self) -> Option<std::any::TypeId> {
Some(TypeId::of::<<C as Configuration>::Struct<'static>>())
}
}

impl<C> IngredientImpl<C>
Expand Down
6 changes: 5 additions & 1 deletion src/salsa_struct.rs
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
pub trait SalsaStructInDb {}
use crate::{plumbing::JarAux, IngredientIndex};

pub trait SalsaStructInDb {
fn lookup_ingredient_index(aux: &dyn JarAux) -> Option<IngredientIndex>;
}
9 changes: 7 additions & 2 deletions src/tracked_struct.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut};
use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut};

use crossbeam::{atomic::AtomicCell, queue::SegQueue};
use tracked_field::FieldIngredientImpl;
Expand Down Expand Up @@ -112,6 +112,10 @@ impl<C: Configuration> Jar for JarImpl<C> {
}))
.collect()
}

fn salsa_struct_type_id(&self) -> Option<TypeId> {
Some(TypeId::of::<<C as Configuration>::Struct<'static>>())
}
}

pub trait TrackedStructInDb: SalsaStructInDb {
Expand Down Expand Up @@ -501,7 +505,8 @@ where
// and the code that references the memo-table has a read-lock.
let memo_table = unsafe { (*data).take_memo_table() };
for (memo_ingredient_index, memo) in memo_table.into_memos() {
let ingredient_index = zalsa.ingredient_index_for_memo(memo_ingredient_index);
let ingredient_index =
zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index);

let executor = DatabaseKeyIndex {
ingredient_index,
Expand Down
60 changes: 41 additions & 19 deletions src/zalsa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use append_only_vec::AppendOnlyVec;
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};
use rustc_hash::FxHashMap;
use std::any::{Any, TypeId};
use std::marker::PhantomData;
Expand Down Expand Up @@ -119,8 +119,10 @@ pub struct Zalsa {

nonce: Nonce<StorageNonce>,

/// Number of memo ingredient indices created by calls to [`next_memo_ingredient_index`](`Self::next_memo_ingredient_index`)
memo_ingredients: Mutex<Vec<IngredientIndex>>,
/// Map from the [`IngredientIndex::as_usize`][] of a salsa struct to a list of
/// [ingredient-indices](`IngredientIndex`) for tracked functions that have this salsa struct
/// as input.
memo_ingredient_indices: RwLock<Vec<Vec<IngredientIndex>>>,

/// Map from the type-id of an `impl Jar` to the index of its first ingredient.
/// This is using a `Mutex<FxHashMap>` (versus, say, a `FxDashMap`)
Expand Down Expand Up @@ -152,7 +154,7 @@ impl Zalsa {
ingredients_vec: AppendOnlyVec::new(),
ingredients_requiring_reset: AppendOnlyVec::new(),
runtime: Runtime::default(),
memo_ingredients: Default::default(),
memo_ingredient_indices: Default::default(),
}
}

Expand Down Expand Up @@ -186,21 +188,22 @@ impl Zalsa {
{
let jar_type_id = jar.type_id();
let mut jar_map = self.jar_map.lock();
*jar_map
.entry(jar_type_id)
.or_insert_with(|| {
let index = IngredientIndex::from(self.ingredients_vec.len());
let ingredients = jar.create_ingredients(self, index);
let mut should_create = false;
let index = *jar_map.entry(jar_type_id).or_insert_with(|| {
should_create = true;
IngredientIndex::from(self.ingredients_vec.len())
});
if should_create {
let aux = JarAuxImpl(self, &jar_map);
let ingredients = jar.create_ingredients(&aux, index);
for ingredient in ingredients {
let expected_index = ingredient.ingredient_index();

if ingredient.requires_reset_for_new_revision() {
self.ingredients_requiring_reset.push(expected_index);
}

let actual_index = self
.ingredients_vec
.push(ingredient);
let actual_index = self.ingredients_vec.push(ingredient);
assert_eq!(
expected_index.as_usize(),
actual_index,
Expand All @@ -209,10 +212,10 @@ impl Zalsa {
expected_index,
actual_index,
);

}
index
})
}

index
}
}

Expand Down Expand Up @@ -290,15 +293,34 @@ impl Zalsa {

pub(crate) fn ingredient_index_for_memo(
&self,
struct_ingredient_index: IngredientIndex,
memo_ingredient_index: MemoIngredientIndex,
) -> IngredientIndex {
self.memo_ingredients.lock()[memo_ingredient_index.as_usize()]
self.memo_ingredient_indices.read()[struct_ingredient_index.as_usize()]
[memo_ingredient_index.as_usize()]
}
}

impl JarAux for Zalsa {
fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex {
let mut memo_ingredients = self.memo_ingredients.lock();
struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap<TypeId, IngredientIndex>);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed salsa struct ingredient lookup as I said here but I had to change this to avoid deadlocks 😢


impl JarAux for JarAuxImpl<'_> {
fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option<IngredientIndex> {
self.1.get(&jar.type_id()).map(ToOwned::to_owned)
}

fn next_memo_ingredient_index(
&self,
struct_ingredient_index: IngredientIndex,
ingredient_index: IngredientIndex,
) -> MemoIngredientIndex {
let mut memo_ingredients = self.0.memo_ingredient_indices.write();
let idx = struct_ingredient_index.as_usize();
let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) {
memo_ingredients
} else {
memo_ingredients.resize_with(idx + 1, Vec::new);
memo_ingredients.get_mut(idx).unwrap()
};
let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap());
memo_ingredients.push(ingredient_index);
mi
Expand Down
25 changes: 25 additions & 0 deletions tests/tracked_fn_multiple_args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//! Test that a `tracked` fn on multiple salsa struct args
//! compiles and executes successfully.

#[salsa::input]
struct MyInput {
field: u32,
}

#[salsa::interned]
struct MyInterned<'db> {
field: u32,
}

#[salsa::tracked]
fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput, interned: MyInterned<'db>) -> u32 {
input.field(db) + interned.field(db)
}

#[test]
fn execute() {
let db = salsa::DatabaseImpl::new();
let input = MyInput::new(&db, 22);
let interned = MyInterned::new(&db, 33);
assert_eq!(tracked_fn(&db, input, interned), 55);
}
Loading