Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions benches/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::BTreeSet;
use std::iter::IntoIterator;

use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion};
use salsa::{CycleRecoveryAction, Database as Db, Setter};
use salsa::{Database as Db, Setter};

/// A Use of a symbol.
#[salsa::input]
Expand Down Expand Up @@ -78,10 +78,10 @@ fn def_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
_last_provisional_value: &Type,
value: &Type,
value: Type,
count: u32,
_def: Definition,
) -> CycleRecoveryAction<Type> {
) -> Type {
cycle_recover(value, count)
}

Expand All @@ -93,24 +93,24 @@ fn use_cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
_last_provisional_value: &Type,
value: &Type,
value: Type,
count: u32,
_use: Use,
) -> CycleRecoveryAction<Type> {
) -> Type {
cycle_recover(value, count)
}

fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction<Type> {
match value {
Type::Bottom => CycleRecoveryAction::Iterate,
fn cycle_recover(value: Type, count: u32) -> Type {
match &value {
Type::Bottom => value,
Type::Values(_) => {
if count > 4 {
CycleRecoveryAction::Fallback(Type::Top)
Type::Top
} else {
CycleRecoveryAction::Iterate
value
}
}
Type::Top => CycleRecoveryAction::Iterate,
Type::Top => value,
}
}

Expand Down
14 changes: 6 additions & 8 deletions book/src/cycles.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@ fn query(db: &dyn salsa::Database) -> u32 {
// ...
}

fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction<u32> {
salsa::CycleRecoveryAction::Iterate
fn cycle_fn(_db: &dyn KnobsDatabase, _id: salsa::Id, _last_provisional_value: &u32, value: u32, _count: u32) -> u32 {
value
}

fn cycle_initial(_db: &dyn KnobsDatabase, _id: salsa::Id) -> u32 {
0
}
```

The `cycle_fn` is optional. The default implementation always returns `Iterate`.
The `cycle_fn` is optional. The default implementation always returns the computed `value`.

If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that fixpoint iteration should continue with the given value (which should be a value that will converge quickly).
If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `cycle_initial` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. When `cycle2` returns its provisional result back to `cycle`, `cycle` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the last provisional value, the newly computed value, and the number of iterations that have occurred so far). The `cycle_fn` can return the `value` parameter to continue iterating with the computed value, or return a different value (a fallback value) to continue iteration with that value instead.

The cycle will iterate until it converges: that is, until two successive iterations produce the same result.

If the `cycle_fn` returns `Fallback`, the cycle will still continue to iterate (using the given value as a new starting point), in order to verify that the fallback value results in a stable converged cycle. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution.
The cycle will iterate until it converges: that is, until the value returned by `cycle_fn` equals the value from the previous iteration.

If a cycle iterates more than 200 times, Salsa will panic rather than iterate forever.

Expand All @@ -40,7 +38,7 @@ Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls
Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone.

To guarantee convergence, you can leverage the `last_provisional_value` (3rd parameter) received by `cycle_fn`.
When the `cycle_fn` recalculates a value, you can implement a strategy that references the last provisional value to "join" values ​​or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values ​​between cycles.
When the `cycle_fn` receives a newly computed value, you can implement a strategy that references the last provisional value to "join" values or "widen" it and return a fallback value. This ensures monotonicity of the calculation and suppresses infinite oscillation of values between cycles. For example:

Also, in fixed-point iteration, it is advantageous to be able to identify which cycle head seeded a value. By embedding a `salsa::Id` (2nd parameter) in the initial value as a "cycle marker", the recovery function can detect self-originated recursion.

Expand Down
4 changes: 2 additions & 2 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ macro_rules! setup_tracked_fn {
db: &$db_lt dyn $Db,
id: salsa::Id,
last_provisional_value: &Self::Output<$db_lt>,
value: &Self::Output<$db_lt>,
value: Self::Output<$db_lt>,
iteration_count: u32,
($($input_id),*): ($($interned_input_ty),*)
) -> $zalsa::CycleRecoveryAction<Self::Output<$db_lt>> {
) -> Self::Output<$db_lt> {
$($cycle_recovery_fn)*(db, id, last_provisional_value, value, iteration_count, $($input_id),*)
}

Expand Down
4 changes: 2 additions & 2 deletions components/salsa-macro-rules/src/unexpected_cycle_recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
#[macro_export]
macro_rules! unexpected_cycle_recovery {
($db:ident, $id:ident, $last_provisional_value:ident, $new_value:ident, $count:ident, $($other_inputs:ident),*) => {{
let (_db, _id, _last_provisional_value, _new_value, _count) = ($db, $id, $last_provisional_value, $new_value, $count);
let (_db, _id, _last_provisional_value, _count) = ($db, $id, $last_provisional_value, $count);
std::mem::drop(($($other_inputs,)*));
salsa::CycleRecoveryAction::Iterate
$new_value
}};
}

Expand Down
31 changes: 5 additions & 26 deletions src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
//!
//! When a query observes that it has just computed a result which contains itself as a cycle head,
//! it recognizes that it is responsible for resolving this cycle and calls its `cycle_fn` to
//! decide how to do so. The `cycle_fn` function is passed the provisional value just computed for
//! that query and the count of iterations so far, and must return either
//! `CycleRecoveryAction::Iterate` (which signals that the cycle head should re-iterate the cycle),
//! or `CycleRecoveryAction::Fallback` (which signals that the cycle head should replace its
//! computed value with the given fallback value).
//! decide what value to use. The `cycle_fn` function is passed the provisional value just computed
//! for that query and the count of iterations so far, and returns the value to use for this
//! iteration. This can be the computed value itself, or a different value (e.g., a fallback value).
//!
//! If the cycle head ever observes that the provisional value it just recomputed is the same as
//! the provisional value from the previous iteration, the cycle has converged. The cycle head will
//! If the cycle head ever observes that the value returned by `cycle_fn` is the same as the
//! provisional value from the previous iteration, this cycle has converged. The cycle head will
//! mark that value as final (by removing itself as cycle head) and return it.
//!
//! Other queries in the cycle will still have provisional values recorded, but those values should
Expand All @@ -39,11 +37,6 @@
//! of its cycle heads have a final result, in which case it, too, can be marked final. (This is
//! implemented in `shallow_verify_memo` and `validate_provisional`.)
//!
//! If the `cycle_fn` returns a fallback value, the cycle head will replace its provisional value
//! with that fallback, and then iterate the cycle one more time. A fallback value is expected to
//! result in a stable, converged cycle. If it does not (that is, if the result of another
//! iteration of the cycle is not the same as the fallback value), we'll panic.
//!
//! In nested cycle cases, the inner cycles are iterated as part of the outer cycle iteration. This helps
//! to significantly reduce the number of iterations needed to reach a fixpoint. For nested cycles,
//! the inner cycles head will transfer their lock ownership to the outer cycle. This ensures
Expand All @@ -64,20 +57,6 @@ use crate::Revision;
/// Should only be relevant in case of a badly configured cycle recovery.
pub const MAX_ITERATIONS: IterationCount = IterationCount(200);

/// Return value from a cycle recovery function.
#[derive(Debug)]
pub enum CycleRecoveryAction<T> {
/// Iterate the cycle again to look for a fixpoint.
Iterate,

/// Use the given value as the result for the current iteration instead
/// of the value computed by the query function.
///
/// Returning `Fallback` doesn't stop the fixpoint iteration. It only
/// allows the iterate function to return a different value.
Fallback(T),
}

/// Cycle recovery strategy: Is this query capable of recovering from
/// a cycle that results from executing the function? If so, how?
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
Expand Down
25 changes: 16 additions & 9 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::ptr::NonNull;
use std::sync::atomic::Ordering;
use std::sync::OnceLock;

use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy, IterationCount, ProvisionalStatus};
use crate::cycle::{CycleRecoveryStrategy, IterationCount, ProvisionalStatus};
use crate::database::RawDatabase;
use crate::function::delete::DeletedEntries;
use crate::hash::{FxHashSet, FxIndexSet};
Expand Down Expand Up @@ -91,9 +91,11 @@ pub trait Configuration: Any {
input: Self::Input<'db>,
) -> Self::Output<'db>;

/// Decide whether to iterate a cycle again or fallback. `value` is the provisional return
/// value from the latest iteration of this cycle. `count` is the number of cycle iterations
/// completed so far.
/// Decide what value to use for this cycle iteration. Takes ownership of the new value
/// and returns an owned value to use.
///
/// The function is called for every iteration of the cycle head, regardless of whether the cycle
/// has converged (the values are equal).
///
/// # Id
///
Expand All @@ -112,17 +114,22 @@ pub trait Configuration: Any {
/// * **Initial value**: `iteration` may be non-zero on the first call for a given query if that
/// query becomes the outermost cycle head after a nested cycle complete a few iterations. In this case,
/// `iteration` continues from the nested cycle's iteration count rather than resetting to zero.
/// * **Non-contiguous values**: This function isn't called if this cycle is part of an outer cycle
/// and the value for this query remains unchanged for one iteration. But the outer cycle might
/// keep iterating because other heads keep changing.
/// * **Non-contiguous values**: The iteration count can be non-contigious for cycle heads
/// that are only conditionally part of a cycle.
///
/// # Return value
///
/// The function should return the value to use for this iteration. This can be the `value`
/// that was computed, or a different value (e.g., a fallback value). This cycle will continue
/// iterating until the returned value equals the previous iteration's value.
fn recover_from_cycle<'db>(
db: &'db Self::DbView,
id: Id,
last_provisional_value: &Self::Output<'db>,
new_value: &Self::Output<'db>,
value: Self::Output<'db>,
iteration: u32,
input: Self::Input<'db>,
) -> CycleRecoveryAction<Self::Output<'db>>;
) -> Self::Output<'db>;

/// Serialize the output type using `serde`.
///
Expand Down
43 changes: 15 additions & 28 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,6 @@ where
I am a cycle head, comparing last provisional value with new value"
);

let mut this_converged = C::values_equal(&new_value, last_provisional_value);

// If this is the outermost cycle, use the maximum iteration count of all cycles.
// This is important for when later iterations introduce new cycle heads (that then
// become the outermost cycle). We want to ensure that the iteration count keeps increasing
Expand All @@ -373,36 +371,25 @@ where
iteration_count
};

if !this_converged {
// We are in a cycle that hasn't converged; ask the user's
// cycle-recovery function what to do:
match C::recover_from_cycle(
db,
id,
last_provisional_value,
&new_value,
iteration_count.as_u32(),
C::id_to_input(zalsa, id),
) {
crate::CycleRecoveryAction::Iterate => {}
crate::CycleRecoveryAction::Fallback(fallback_value) => {
tracing::debug!(
"{database_key_index:?}: execute: user cycle_fn says to fall back"
);
new_value = fallback_value;

this_converged = C::values_equal(&new_value, last_provisional_value);
}
}
// We are in a cycle that hasn't converged; ask the user's
// cycle-recovery function what to do (it may return the same value or a different one):
new_value = C::recover_from_cycle(
db,
id,
last_provisional_value,
new_value,
iteration_count.as_u32(),
C::id_to_input(zalsa, id),
);

let new_cycle_heads = active_query.take_cycle_heads();
for head in new_cycle_heads {
if !cycle_heads.contains(&head.database_key_index) {
panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index);
}
let new_cycle_heads = active_query.take_cycle_heads();
for head in new_cycle_heads {
if !cycle_heads.contains(&head.database_key_index) {
panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index);
}
}

let this_converged = C::values_equal(&new_value, last_provisional_value);
let mut completed_query = active_query.pop();

if let Some(outer_cycle) = outer_cycle {
Expand Down
8 changes: 4 additions & 4 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ mod _memory_usage {
use crate::plumbing::{self, IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb};
use crate::table::memo::MemoTableWithTypes;
use crate::zalsa::Zalsa;
use crate::{CycleRecoveryAction, Database, Id, Revision};
use crate::{Database, Id, Revision};

use std::any::TypeId;
use std::num::NonZeroUsize;
Expand Down Expand Up @@ -564,11 +564,11 @@ mod _memory_usage {
_: &'db Self::DbView,
_: Id,
_: &Self::Output<'db>,
_: &Self::Output<'db>,
value: Self::Output<'db>,
_: u32,
_: Self::Input<'db>,
) -> CycleRecoveryAction<Self::Output<'db>> {
unimplemented!()
) -> Self::Output<'db> {
value
}

fn serialize<S>(_: &Self::Output<'_>, _: S) -> Result<S::Ok, S::Error>
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use self::database::IngredientInfo;
pub use self::accumulator::Accumulator;
pub use self::active_query::Backtrace;
pub use self::cancelled::Cancelled;
pub use self::cycle::CycleRecoveryAction;

pub use self::database::Database;
pub use self::database_impl::DatabaseImpl;
pub use self::durability::Durability;
Expand Down Expand Up @@ -92,7 +92,7 @@ pub mod plumbing {
#[cfg(feature = "accumulator")]
pub use crate::accumulator::Accumulator;
pub use crate::attach::{attach, with_attached_database};
pub use crate::cycle::{CycleRecoveryAction, CycleRecoveryStrategy};
pub use crate::cycle::CycleRecoveryStrategy;
pub use crate::database::{current_revision, Database};
pub use crate::durability::Durability;
pub use crate::id::{AsId, FromId, FromIdWithDb, Id};
Expand Down
20 changes: 11 additions & 9 deletions tests/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
mod common;
use common::{ExecuteValidateLoggerDatabase, LogDatabase};
use expect_test::expect;
use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter};
use salsa::{Database as Db, DatabaseImpl as DbImpl, Durability, Setter};
#[cfg(not(miri))]
use test_log::test;

Expand Down Expand Up @@ -122,24 +122,26 @@ const MAX_ITERATIONS: u32 = 3;

/// Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds,
/// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else
/// iterating again.
/// returning the computed value to continue iterating.
fn cycle_recover(
_db: &dyn Db,
_id: salsa::Id,
_last_provisional_value: &Value,
value: &Value,
last_provisional_value: &Value,
value: Value,
count: u32,
_inputs: Inputs,
) -> CycleRecoveryAction<Value> {
if value
) -> Value {
if &value == last_provisional_value {
value
} else if value
.to_value()
.is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE)
{
CycleRecoveryAction::Fallback(Value::OutOfBounds)
Value::OutOfBounds
} else if count > MAX_ITERATIONS {
CycleRecoveryAction::Fallback(Value::TooManyIterations)
Value::TooManyIterations
} else {
CycleRecoveryAction::Iterate
value
}
}

Expand Down
6 changes: 3 additions & 3 deletions tests/cycle_accumulate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ fn cycle_fn(
_db: &dyn LogDatabase,
_id: salsa::Id,
_last_provisional_value: &[u32],
_value: &[u32],
value: Vec<u32>,
_count: u32,
_file: File,
) -> salsa::CycleRecoveryAction<Vec<u32>> {
salsa::CycleRecoveryAction::Iterate
) -> Vec<u32> {
value
}

#[test]
Expand Down
Loading