From 7dd8605ee3fca03b7dc7b460d07352a785b87515 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Fri, 17 Jun 2022 14:58:31 -0700 Subject: [PATCH] Backtrack execution for missing digests to make `eager_fetch=false` more resilient (#15850) As described in #11331, in order to avoid having to deal with missing remote content later in the pipeline, `--remote-cache-eager-fetch` currently defaults to true. This means that before calling a cache hit a hit, we fully download the output of the cache entry. In warm-cache situations, this can mean downloading a lot more than is strictly necessary. In theory, you could imagine `eager_fetch=False` downloading only stdio and no file content at all for a 100% cache hit rate run of tests. In practice, high hitrate runs [see about 80% fewer bytes downloaded, and 50% fewer RPCs](https://github.com/pantsbuild/pants/issues/11331#issuecomment-1130363438) than with `eager_fetch=True`. To begin moving toward disabling `eager_fetch` by default (and eventually, ideally, removing the flag entirely), this change begins "backtracking" when missing digests are encountered. Backtracking is implemented by "catching" `MissingDigest` errors (introduced in #15761), and invalidating their source `Node` in the graph. When a `Node` that produced a missing digest re-runs, it does so using progressively fewer caches (as introduced in #15854), in order to cache bust both local and remote partial cache entries. `eager_fetch=False` was already experimental, in that any `MissingDigest` error encountered later in the run would kill the entire run. Backtracking makes `eager_fetch=False` less experimental, in that we are now very likely to recover from a `MissingDigest` error. But it is still the case with `eager_fetch=False` that persistent remote infrastructure errors (those that last longer than our retry budget or timeout) could kill a run. Given that, we will likely want to gain more experience and further tune timeouts and retries before changing the default. 
Fixes #11331. [ci skip-build-wheels] --- pants.toml | 2 + .../pants/engine/internals/native_engine.pyi | 9 +- src/rust/engine/fs/store/src/remote_tests.rs | 8 +- src/rust/engine/fs/store/src/tests.rs | 4 +- .../engine/process_execution/src/cache.rs | 39 +- .../process_execution/src/cache_tests.rs | 1 + .../engine/process_execution/src/local.rs | 4 +- .../process_execution/src/nailgun/mod.rs | 7 +- .../src/remote_cache_tests.rs | 187 ++-- .../process_execution/src/remote_tests.rs | 22 +- src/rust/engine/src/context.rs | 133 ++- src/rust/engine/src/externs/testutil.rs | 24 +- src/rust/engine/src/intrinsics.rs | 2 +- src/rust/engine/src/nodes.rs | 67 +- src/rust/engine/testutil/Cargo.toml | 4 +- src/rust/engine/testutil/mock/Cargo.toml | 2 +- ...ction_cache.rs => action_cache_service.rs} | 116 +-- src/rust/engine/testutil/mock/src/cas.rs | 816 ++---------------- .../engine/testutil/mock/src/cas_service.rs | 685 +++++++++++++++ src/rust/engine/testutil/mock/src/lib.rs | 4 +- src/rust/engine/workunit_store/src/metrics.rs | 2 + .../remote_cache_integration_test.py | 105 ++- 22 files changed, 1240 insertions(+), 1003 deletions(-) rename src/rust/engine/testutil/mock/src/{action_cache.rs => action_cache_service.rs} (54%) create mode 100644 src/rust/engine/testutil/mock/src/cas_service.rs diff --git a/pants.toml b/pants.toml index 30efc9b33ce..0d420f1bdf5 100644 --- a/pants.toml +++ b/pants.toml @@ -71,6 +71,8 @@ unmatched_build_file_globs = "error" remote_store_address = "grpcs://cache.toolchain.com:443" remote_instance_name = "main" remote_auth_plugin = "toolchain.pants.auth.plugin:toolchain_auth_plugin" +# See https://github.com/pantsbuild/pants/issues/11331. 
+remote_cache_eager_fetch = false [anonymous-telemetry] enabled = true diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 6f1604dbb96..2905c61994c 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -178,7 +178,8 @@ class PantsdClientException(Exception): # ------------------------------------------------------------------------------ class PyStubCASBuilder: - def always_errors(self) -> PyStubCASBuilder: ... + def ac_always_errors(self) -> PyStubCASBuilder: ... + def cas_always_errors(self) -> PyStubCASBuilder: ... def build(self, executor: PyExecutor) -> PyStubCAS: ... class PyStubCAS: @@ -186,6 +187,12 @@ class PyStubCAS: def builder(cls) -> PyStubCASBuilder: ... @property def address(self) -> str: ... + def remove(self, digest: FileDigest) -> bool: ... + def action_cache_len(self) -> int: ... + +# ------------------------------------------------------------------------------ +# (etc.) 
+# ------------------------------------------------------------------------------ class RawFdRunner(Protocol): def __call__( diff --git a/src/rust/engine/fs/store/src/remote_tests.rs b/src/rust/engine/fs/store/src/remote_tests.rs index 4c5c1a7729e..52363edc543 100644 --- a/src/rust/engine/fs/store/src/remote_tests.rs +++ b/src/rust/engine/fs/store/src/remote_tests.rs @@ -65,7 +65,7 @@ async fn missing_directory() { #[tokio::test] async fn load_file_grpc_error() { let _ = WorkunitStore::setup_for_tests(); - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let error = load_file_bytes(&new_byte_store(&cas), TestData::roland().digest()) .await @@ -80,7 +80,7 @@ async fn load_file_grpc_error() { #[tokio::test] async fn load_directory_grpc_error() { let _ = WorkunitStore::setup_for_tests(); - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let error = load_directory_proto_bytes( &new_byte_store(&cas), @@ -212,7 +212,7 @@ async fn write_empty_file() { #[tokio::test] async fn write_file_errors() { - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let store = new_byte_store(&cas); let error = store @@ -290,7 +290,7 @@ async fn list_missing_digests_some_missing() { #[tokio::test] async fn list_missing_digests_error() { let _ = WorkunitStore::setup_for_tests(); - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let store = new_byte_store(&cas); diff --git a/src/rust/engine/fs/store/src/tests.rs b/src/rust/engine/fs/store/src/tests.rs index 0a950b3d30b..da663cef717 100644 --- a/src/rust/engine/fs/store/src/tests.rs +++ b/src/rust/engine/fs/store/src/tests.rs @@ -278,7 +278,7 @@ async fn load_file_remote_error_is_error() { let dir = TempDir::new().unwrap(); let _ = WorkunitStore::setup_for_tests(); - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let error = load_file_bytes( &new_store(dir.path(), &cas.address()), 
TestData::roland().digest(), @@ -303,7 +303,7 @@ async fn load_directory_remote_error_is_error() { let dir = TempDir::new().unwrap(); let _ = WorkunitStore::setup_for_tests(); - let cas = StubCAS::always_errors(); + let cas = StubCAS::cas_always_errors(); let error = new_store(dir.path(), &cas.address()) .load_directory(TestData::roland().digest()) .await diff --git a/src/rust/engine/process_execution/src/cache.rs b/src/rust/engine/process_execution/src/cache.rs index 8afebdbadda..baaf7bc73be 100644 --- a/src/rust/engine/process_execution/src/cache.rs +++ b/src/rust/engine/process_execution/src/cache.rs @@ -33,6 +33,7 @@ pub struct CommandRunner { inner: Arc, cache: PersistentCache, file_store: Store, + eager_fetch: bool, metadata: ProcessMetadata, } @@ -41,12 +42,14 @@ impl CommandRunner { inner: Arc, cache: PersistentCache, file_store: Store, + eager_fetch: bool, metadata: ProcessMetadata, ) -> CommandRunner { CommandRunner { inner, cache, file_store, + eager_fetch, metadata, } } @@ -196,22 +199,26 @@ impl CommandRunner { return Ok(None); }; - // Ensure that all digests in the result are loadable, erroring if any are not. - let _ = future::try_join_all(vec![ - self - .file_store - .ensure_local_has_file(result.stdout_digest) - .boxed(), - self - .file_store - .ensure_local_has_file(result.stderr_digest) - .boxed(), - self - .file_store - .ensure_local_has_recursive_directory(result.output_directory.clone()) - .boxed(), - ]) - .await?; + // If eager_fetch is enabled, ensure that all digests in the result are loadable, erroring + // if any are not. If eager_fetch is disabled, a Digest which is discovered to be missing later + // on during execution will cause backtracking. 
+ if self.eager_fetch { + let _ = future::try_join_all(vec![ + self + .file_store + .ensure_local_has_file(result.stdout_digest) + .boxed(), + self + .file_store + .ensure_local_has_file(result.stderr_digest) + .boxed(), + self + .file_store + .ensure_local_has_recursive_directory(result.output_directory.clone()) + .boxed(), + ]) + .await?; + } Ok(Some(result)) } diff --git a/src/rust/engine/process_execution/src/cache_tests.rs b/src/rust/engine/process_execution/src/cache_tests.rs index eeb185b1ed4..05f41059002 100644 --- a/src/rust/engine/process_execution/src/cache_tests.rs +++ b/src/rust/engine/process_execution/src/cache_tests.rs @@ -58,6 +58,7 @@ fn create_cached_runner( local.into(), cache, store, + true, ProcessMetadata::default(), )); diff --git a/src/rust/engine/process_execution/src/local.rs b/src/rust/engine/process_execution/src/local.rs index 9f9fea2e219..2e94ab43ec1 100644 --- a/src/rust/engine/process_execution/src/local.rs +++ b/src/rust/engine/process_execution/src/local.rs @@ -483,7 +483,7 @@ pub trait CapturedWorkdir { workdir_token: Self::WorkdirToken, exclusive_spawn: bool, platform: Platform, - ) -> Result { + ) -> Result { let start_time = Instant::now(); // Spawn the process. @@ -578,7 +578,7 @@ pub trait CapturedWorkdir { metadata: result_metadata, }) } - Err(msg) => Err(msg.into()), + Err(msg) => Err(msg), } } diff --git a/src/rust/engine/process_execution/src/nailgun/mod.rs b/src/rust/engine/process_execution/src/nailgun/mod.rs index a69cdf0c67a..fd1693491dc 100644 --- a/src/rust/engine/process_execution/src/nailgun/mod.rs +++ b/src/rust/engine/process_execution/src/nailgun/mod.rs @@ -149,7 +149,8 @@ impl super::CommandRunner for CommandRunner { client_args, client_main_class, .. 
- } = ParsedJVMCommandLines::parse_command_lines(&req.argv)?; + } = ParsedJVMCommandLines::parse_command_lines(&req.argv) + .map_err(ProcessError::Unclassified)?; let nailgun_name = CommandRunner::calculate_nailgun_name(&client_main_class); let (client_input_digests, server_input_digests) = @@ -173,7 +174,7 @@ impl super::CommandRunner for CommandRunner { self.inner.immutable_inputs(), ) .await - .map_err(|e| format!("Failed to connect to nailgun! {}", e))?; + .map_err(|e| e.enrich("Failed to connect to nailgun"))?; // Prepare the workdir. let exclusive_spawn = prepare_workdir( @@ -204,7 +205,7 @@ impl super::CommandRunner for CommandRunner { // release, it assumes that it has been canceled and kills the server. nailgun_process.release().await?; - res + Ok(res?) } ) .await diff --git a/src/rust/engine/process_execution/src/remote_cache_tests.rs b/src/rust/engine/process_execution/src/remote_cache_tests.rs index 392cbdeab2e..9a5a4100ae5 100644 --- a/src/rust/engine/process_execution/src/remote_cache_tests.rs +++ b/src/rust/engine/process_execution/src/remote_cache_tests.rs @@ -5,17 +5,17 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; +use maplit::hashset; +use tempfile::TempDir; +use tokio::time::sleep; + use fs::{DirectoryDigest, RelativePath, EMPTY_DIRECTORY_DIGEST}; use grpc_util::tls; use hashing::{Digest, EMPTY_DIGEST}; -use maplit::hashset; -use mock::{StubActionCache, StubCAS}; +use mock::StubCAS; use protos::gen::build::bazel::remote::execution::v2 as remexec; -use remexec::ActionResult; use store::Store; -use tempfile::TempDir; use testutil::data::{TestData, TestDirectory, TestTree}; -use tokio::time::sleep; use workunit_store::{RunId, RunningWorkunit, WorkunitStore}; use crate::remote::{ensure_action_stored_locally, make_execute_request}; @@ -74,14 +74,17 @@ impl CommandRunnerTrait for MockLocalCommandRunner { struct StoreSetup { pub store: Store, pub _store_temp_dir: TempDir, - pub _cas: StubCAS, + pub cas: StubCAS, pub 
executor: task_executor::Executor, } impl StoreSetup { - pub fn new() -> StoreSetup { + pub fn new() -> Self { + Self::new_with_stub_cas(StubCAS::builder().build()) + } + + pub fn new_with_stub_cas(cas: StubCAS) -> Self { let executor = task_executor::Executor::new(); - let cas = StubCAS::builder().build(); let store_temp_dir = TempDir::new().unwrap(); let store_dir = store_temp_dir.path().join("store_dir"); let store = Store::local_only(executor.clone(), store_dir) @@ -99,10 +102,10 @@ impl StoreSetup { 4 * 1024 * 1024, ) .unwrap(); - StoreSetup { + Self { store, _store_temp_dir: store_temp_dir, - _cas: cas, + cas, executor, } } @@ -124,18 +127,15 @@ fn create_local_runner( fn create_cached_runner( local: Box, store_setup: &StoreSetup, - read_delay_ms: u64, - write_delay_ms: u64, eager_fetch: bool, -) -> (Box, StubActionCache) { - let action_cache = StubActionCache::new_with_delays(read_delay_ms, write_delay_ms).unwrap(); - let runner = Box::new( +) -> Box { + Box::new( crate::remote_cache::CommandRunner::new( local.into(), ProcessMetadata::default(), store_setup.executor.clone(), store_setup.store.clone(), - &action_cache.address(), + &store_setup.cas.address(), None, BTreeMap::default(), Platform::current().unwrap(), @@ -147,54 +147,40 @@ fn create_cached_runner( CACHE_READ_TIMEOUT, ) .expect("caching command runner"), - ); - (runner, action_cache) + ) } -async fn create_process(store: &Store) -> (Process, Digest) { +// TODO: Unfortunately, this code cannot be moved to the `testutil::mock` crate, because that +// introduces a cycle between this crate and that one. 
+async fn create_process(store_setup: &StoreSetup) -> (Process, Digest) { let process = Process::new(vec![ "this process will not execute: see MockLocalCommandRunner".to_string(), ]); let (action, command, _exec_request) = make_execute_request(&process, ProcessMetadata::default()).unwrap(); - let (_command_digest, action_digest) = ensure_action_stored_locally(store, &command, &action) - .await - .unwrap(); + let (_command_digest, action_digest) = + ensure_action_stored_locally(&store_setup.store, &command, &action) + .await + .unwrap(); (process, action_digest) } -fn insert_into_action_cache( - action_cache: &StubActionCache, - action_digest: &Digest, - exit_code: i32, - stdout_digest: Digest, - stderr_digest: Digest, -) { - let action_result = ActionResult { - exit_code, - stdout_digest: Some(stdout_digest.into()), - stderr_digest: Some(stderr_digest.into()), - ..ActionResult::default() - }; - action_cache - .action_map - .lock() - .insert(action_digest.hash, action_result); -} - #[tokio::test] async fn cache_read_success() { let (_, mut workunit) = WorkunitStore::setup_for_tests(); let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(1, 1000); - let (cache_runner, action_cache) = create_cached_runner(local_runner, &store_setup, 0, 0, false); + let cache_runner = create_cached_runner(local_runner, &store_setup, false); - let (process, action_digest) = create_process(&store_setup.store).await; - insert_into_action_cache(&action_cache, &action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); + let (process, action_digest) = create_process(&store_setup).await; + store_setup + .cas + .action_cache + .insert(action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); let remote_result = cache_runner - .run(Context::default(), &mut workunit, process.clone().into()) + .run(Context::default(), &mut workunit, process.into()) .await .unwrap(); assert_eq!(remote_result.exit_code, 
0); @@ -208,11 +194,18 @@ async fn cache_read_skipped_on_action_cache_errors() { let (workunit_store, mut workunit) = WorkunitStore::setup_for_tests(); let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(1, 500); - let (cache_runner, action_cache) = create_cached_runner(local_runner, &store_setup, 0, 0, false); - - let (process, action_digest) = create_process(&store_setup.store).await; - insert_into_action_cache(&action_cache, &action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); - action_cache.always_errors.store(true, Ordering::SeqCst); + let cache_runner = create_cached_runner(local_runner, &store_setup, false); + + let (process, action_digest) = create_process(&store_setup).await; + store_setup + .cas + .action_cache + .insert(action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); + store_setup + .cas + .action_cache + .always_errors + .store(true, Ordering::SeqCst); assert_eq!( workunit_store.get_metrics().get("remote_cache_read_errors"), @@ -220,7 +213,7 @@ async fn cache_read_skipped_on_action_cache_errors() { ); assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); let remote_result = cache_runner - .run(Context::default(), &mut workunit, process.clone().into()) + .run(Context::default(), &mut workunit, process.into()) .await .unwrap(); assert_eq!(remote_result.exit_code, 1); @@ -238,13 +231,12 @@ async fn cache_read_skipped_on_store_errors() { let (workunit_store, mut workunit) = WorkunitStore::setup_for_tests(); let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(1, 500); - let (cache_runner, action_cache) = create_cached_runner(local_runner, &store_setup, 0, 0, true); + let cache_runner = create_cached_runner(local_runner, &store_setup, true); // Claim that the process has a non-empty and not-persisted stdout digest. 
- let (process, action_digest) = create_process(&store_setup.store).await; - insert_into_action_cache( - &action_cache, - &action_digest, + let (process, action_digest) = create_process(&store_setup).await; + store_setup.cas.action_cache.insert( + action_digest, 0, Digest::of_bytes("pigs flying".as_bytes()), EMPTY_DIGEST, @@ -256,7 +248,7 @@ async fn cache_read_skipped_on_store_errors() { ); assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); let remote_result = cache_runner - .run(Context::default(), &mut workunit, process.clone().into()) + .run(Context::default(), &mut workunit, process.into()) .await .unwrap(); assert_eq!(remote_result.exit_code, 1); @@ -277,13 +269,11 @@ async fn cache_read_eager_fetch() { async fn run_process(eager_fetch: bool, workunit: &mut RunningWorkunit) -> (i32, usize) { let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(1, 1000); - let (cache_runner, action_cache) = - create_cached_runner(local_runner, &store_setup, 0, 0, eager_fetch); + let cache_runner = create_cached_runner(local_runner, &store_setup, eager_fetch); - let (process, action_digest) = create_process(&store_setup.store).await; - insert_into_action_cache( - &action_cache, - &action_digest, + let (process, action_digest) = create_process(&store_setup).await; + store_setup.cas.action_cache.insert( + action_digest, 0, TestData::roland().digest(), TestData::roland().digest(), @@ -291,7 +281,7 @@ async fn cache_read_eager_fetch() { assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); let remote_result = cache_runner - .run(Context::default(), workunit, process.clone().into()) + .run(Context::default(), workunit, process.into()) .await .unwrap(); @@ -318,19 +308,25 @@ async fn cache_read_speculation() { cache_hit: bool, workunit: &mut RunningWorkunit, ) -> (i32, usize) { - let store_setup = StoreSetup::new(); + let store_setup = StoreSetup::new_with_stub_cas( + StubCAS::builder() + 
.ac_read_delay(Duration::from_millis(remote_delay_ms)) + .build(), + ); let (local_runner, local_runner_call_counter) = create_local_runner(1, local_delay_ms); - let (cache_runner, action_cache) = - create_cached_runner(local_runner, &store_setup, remote_delay_ms, 0, false); + let cache_runner = create_cached_runner(local_runner, &store_setup, false); - let (process, action_digest) = create_process(&store_setup.store).await; + let (process, action_digest) = create_process(&store_setup).await; if cache_hit { - insert_into_action_cache(&action_cache, &action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); + store_setup + .cas + .action_cache + .insert(action_digest, 0, EMPTY_DIGEST, EMPTY_DIGEST); } assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); let remote_result = cache_runner - .run(Context::default(), workunit, process.clone().into()) + .run(Context::default(), workunit, process.into()) .await .unwrap(); @@ -359,11 +355,11 @@ async fn cache_write_success() { let (_, mut workunit) = WorkunitStore::setup_for_tests(); let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(0, 100); - let (cache_runner, action_cache) = create_cached_runner(local_runner, &store_setup, 0, 0, false); - let (process, action_digest) = create_process(&store_setup.store).await; + let cache_runner = create_cached_runner(local_runner, &store_setup, false); + let (process, action_digest) = create_process(&store_setup).await; assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); - assert!(action_cache.action_map.lock().is_empty()); + assert!(store_setup.cas.action_cache.action_map.lock().is_empty()); let local_result = cache_runner .run(Context::default(), &mut workunit, process.clone().into()) @@ -374,11 +370,12 @@ async fn cache_write_success() { // Wait for the cache write block to finish. 
sleep(Duration::from_secs(1)).await; - assert_eq!(action_cache.action_map.lock().len(), 1); - let action_map_mutex_guard = action_cache.action_map.lock(); + assert_eq!(store_setup.cas.action_cache.len(), 1); assert_eq!( - action_map_mutex_guard - .get(&action_digest.hash) + store_setup + .cas + .action_cache + .get(action_digest) .unwrap() .exit_code, 0 @@ -390,11 +387,11 @@ async fn cache_write_not_for_failures() { let (_, mut workunit) = WorkunitStore::setup_for_tests(); let store_setup = StoreSetup::new(); let (local_runner, local_runner_call_counter) = create_local_runner(1, 100); - let (cache_runner, action_cache) = create_cached_runner(local_runner, &store_setup, 0, 0, false); - let (process, _action_digest) = create_process(&store_setup.store).await; + let cache_runner = create_cached_runner(local_runner, &store_setup, false); + let (process, _action_digest) = create_process(&store_setup).await; assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); - assert!(action_cache.action_map.lock().is_empty()); + assert!(store_setup.cas.action_cache.action_map.lock().is_empty()); let local_result = cache_runner .run(Context::default(), &mut workunit, process.clone().into()) @@ -405,21 +402,24 @@ async fn cache_write_not_for_failures() { // Wait for the cache write block to finish. sleep(Duration::from_millis(100)).await; - assert!(action_cache.action_map.lock().is_empty()); + assert!(store_setup.cas.action_cache.action_map.lock().is_empty()); } /// Cache writes should be async and not block the CommandRunner from returning. 
#[tokio::test] async fn cache_write_does_not_block() { let (_, mut workunit) = WorkunitStore::setup_for_tests(); - let store_setup = StoreSetup::new(); + let store_setup = StoreSetup::new_with_stub_cas( + StubCAS::builder() + .ac_write_delay(Duration::from_millis(100)) + .build(), + ); let (local_runner, local_runner_call_counter) = create_local_runner(0, 100); - let (cache_runner, action_cache) = - create_cached_runner(local_runner, &store_setup, 0, 100, false); - let (process, action_digest) = create_process(&store_setup.store).await; + let cache_runner = create_cached_runner(local_runner, &store_setup, false); + let (process, action_digest) = create_process(&store_setup).await; assert_eq!(local_runner_call_counter.load(Ordering::SeqCst), 0); - assert!(action_cache.action_map.lock().is_empty()); + assert!(store_setup.cas.action_cache.action_map.lock().is_empty()); let local_result = cache_runner .run(Context::default(), &mut workunit, process.clone().into()) @@ -430,14 +430,15 @@ async fn cache_write_does_not_block() { // We expect the cache write to have not finished yet, even though we already finished // CommandRunner::run(). 
- assert!(action_cache.action_map.lock().is_empty()); + assert!(store_setup.cas.action_cache.action_map.lock().is_empty()); sleep(Duration::from_secs(1)).await; - assert_eq!(action_cache.action_map.lock().len(), 1); - let action_map_mutex_guard = action_cache.action_map.lock(); + assert_eq!(store_setup.cas.action_cache.len(), 1); assert_eq!( - action_map_mutex_guard - .get(&action_digest.hash) + store_setup + .cas + .action_cache + .get(action_digest) .unwrap() .exit_code, 0 @@ -594,13 +595,13 @@ async fn make_action_result_basic() { .expect("Error saving directory"); let mock_command_runner = Arc::new(MockCommandRunner); - let action_cache = StubActionCache::new().unwrap(); + let cas = StubCAS::builder().build(); let runner = crate::remote_cache::CommandRunner::new( mock_command_runner.clone(), ProcessMetadata::default(), executor.clone(), store.clone(), - &action_cache.address(), + &cas.address(), None, BTreeMap::default(), Platform::current().unwrap(), diff --git a/src/rust/engine/process_execution/src/remote_tests.rs b/src/rust/engine/process_execution/src/remote_tests.rs index 637c104c509..ec6d91bbf88 100644 --- a/src/rust/engine/process_execution/src/remote_tests.rs +++ b/src/rust/engine/process_execution/src/remote_tests.rs @@ -1711,9 +1711,9 @@ async fn remote_workunits_are_stored() { .file(&TestData::roland()) .directory(&TestDirectory::containing_roland()) .build(); - let action_cache = mock::StubActionCache::new().unwrap(); - let (command_runner, _store) = - create_command_runner(action_cache.address(), &cas, Platform::Linux_x86_64); + // TODO: This CommandRunner is only used for parsing, and so intentionally passes a CAS/AC + // address rather than an Execution address. 
+ let (command_runner, _store) = create_command_runner(cas.address(), &cas, Platform::Linux_x86_64); command_runner .extract_execute_response(RunId(0), OperationOrStatus::Operation(operation)) @@ -2161,7 +2161,7 @@ pub(crate) async fn run_cmd_runner( } fn create_command_runner( - address: String, + execution_address: String, cas: &mock::StubCAS, platform: Platform, ) -> (CommandRunner, Store) { @@ -2169,7 +2169,7 @@ fn create_command_runner( let store_dir = TempDir::new().unwrap(); let store = make_store(store_dir.path(), cas, runtime.clone()); let command_runner = CommandRunner::new( - &address, + &execution_address, ProcessMetadata::default(), None, BTreeMap::new(), @@ -2185,7 +2185,7 @@ } async fn run_command_remote( - address: String, + execution_address: String, request: Process, ) -> Result { let (_, mut workunit) = WorkunitStore::setup_for_tests(); @@ -2194,7 +2194,8 @@ async fn run_command_remote( .directory(&TestDirectory::containing_roland()) .tree(&TestTree::roland_at_root()) .build(); - let (command_runner, store) = create_command_runner(address, &cas, Platform::Linux_x86_64); + let (command_runner, store) = + create_command_runner(execution_address, &cas, Platform::Linux_x86_64); let original = command_runner .run(Context::default(), &mut workunit, request) .await?; @@ -2238,14 +2239,13 @@ async fn extract_execute_response( operation: Operation, remote_platform: Platform, ) -> Result { - let action_cache = mock::StubActionCache::new().expect("failed to create action cache"); - let cas = mock::StubCAS::builder() .file(&TestData::roland()) .directory(&TestDirectory::containing_roland()) .build(); - let (command_runner, store) = - create_command_runner(action_cache.address(), &cas, remote_platform); + // TODO: This CommandRunner is only used for parsing, and so intentionally passes a CAS/AC + // address rather than an Execution address. 
+ let (command_runner, store) = create_command_runner(cas.address(), &cas, remote_platform); let original = command_runner .extract_execute_response(RunId(0), OperationOrStatus::Operation(operation)) diff --git a/src/rust/engine/src/context.rs b/src/rust/engine/src/context.rs index 18c0a73bd67..798965c8a50 100644 --- a/src/rust/engine/src/context.rs +++ b/src/rust/engine/src/context.rs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0 (see LICENSE). use std::cmp::max; -use std::collections::{BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::convert::{Into, TryInto}; use std::future::Future; use std::io::Read; @@ -12,7 +12,7 @@ use std::sync::Arc; use std::time::Duration; use crate::intrinsics::Intrinsics; -use crate::nodes::{NodeKey, WrappedNode}; +use crate::nodes::{ExecuteProcess, NodeKey, NodeOutput, NodeResult, WrappedNode}; use crate::python::Failure; use crate::session::{Session, Sessions}; use crate::tasks::{Rule, Tasks}; @@ -34,7 +34,7 @@ use rule_graph::RuleGraph; use store::{self, Store}; use task_executor::Executor; use watch::{Invalidatable, InvalidationWatcher}; -use workunit_store::RunId; +use workunit_store::{Metric, RunId, RunningWorkunit}; // The reqwest crate has no support for ingesting multiple certificates in a single file, // and requires single PEM blocks. There is a crate (https://crates.io/crates/pem) that can decode @@ -266,12 +266,14 @@ impl Core { inner_runner: Arc, full_store: &Store, local_cache: &PersistentCache, + eager_fetch: bool, process_execution_metadata: &ProcessMetadata, ) -> Arc { Arc::new(process_execution::cache::CommandRunner::new( inner_runner, local_cache.clone(), full_store.clone(), + eager_fetch, process_execution_metadata.clone(), )) } @@ -363,23 +365,19 @@ impl Core { None }; - // TODO: The local cache eagerly fetches outputs independent of the `eager_fetch` flag. 
Once - // `eager_fetch` backtracks via https://github.com/pantsbuild/pants/issues/11331, the local - // cache will be able to obey `eager_fetch` as well, and can efficiently be used with remote - // execution. - let maybe_local_cached_runner = - if exec_strategy_opts.local_cache && !remoting_opts.execution_enable { - Some(Self::make_local_cached_runner( - maybe_remote_cached_runner - .clone() - .unwrap_or_else(|| leaf_runner.clone()), - full_store, - local_cache, - process_execution_metadata, - )) - } else { - None - }; + let maybe_local_cached_runner = if exec_strategy_opts.local_cache { + Some(Self::make_local_cached_runner( + maybe_remote_cached_runner + .clone() + .unwrap_or_else(|| leaf_runner.clone()), + full_store, + local_cache, + remoting_opts.cache_eager_fetch, + process_execution_metadata, + )) + } else { + None + }; Ok( vec![ @@ -650,6 +648,12 @@ pub struct Context { pub core: Arc, pub session: Session, run_id: RunId, + /// The number of attempts which have been made to backtrack to a particular ExecuteProcess node. + /// + /// Presence in this map at process runtime indicates that the pricess is being retried, and that + /// there was something invalid or unusable about previous attempts. Successive attempts should + /// run in a different mode (skipping caches, etc) to attempt to produce a valid result. + backtrack_attempts: Arc>>, stats: Arc>, } @@ -661,6 +665,7 @@ impl Context { core, session, run_id, + backtrack_attempts: Arc::default(), stats: Arc::default(), } } @@ -668,7 +673,7 @@ impl Context { /// /// Get the future value for the given Node implementation. 
/// - pub async fn get(&self, node: N) -> Result { + pub async fn get(&self, node: N) -> NodeResult { let node_result = self .core .graph @@ -680,6 +685,91 @@ impl Context { .unwrap_or_else(|_| panic!("A Node implementation was ambiguous.")), ) } + + /// + /// If the given Result is a Failure::MissingDigest, attempts to invalidate the Node which was + /// the source of the Digest, potentially causing indirect retry of the Result. + /// + /// If we successfully locate and restart the source of the Digest, converts the Result into a + /// `Failure::Invalidated`, which will cause retry at some level above us. + /// + pub fn maybe_backtrack( + &self, + result: NodeResult, + workunit: &mut RunningWorkunit, + ) -> NodeResult { + let digest = if let Err(Failure::MissingDigest(_, d)) = result.as_ref() { + *d + } else { + return result; + }; + + // Locate the source(s) of this Digest. + // TODO: Currently needs a combination of `visit_live` and `invalidate_from_roots` because + // `invalidate_from_roots` cannot view `Node` results. This could lead to a race condition + // where a `Node` is invalidated multiple times, which might cause it to increment its attempt + // count multiple times. See https://github.com/pantsbuild/pants/issues/15867 + let mut roots = HashSet::new(); + self.core.graph.visit_live(self, |k, v| match k { + NodeKey::ExecuteProcess(p) if v.digests().contains(&digest) => { + roots.insert(p.clone()); + } + _ => (), + }); + + if roots.is_empty() { + // We did not identify any roots to invalidate: allow the Node to fail. + return result; + } + + // Trigger backtrack attempts for the matched Nodes. 
+ { + let mut backtrack_attempts = self.backtrack_attempts.lock(); + for root in &roots { + let attempt = backtrack_attempts.entry((**root).clone()).or_insert(1); + let description = &root.process.description; + workunit.increment_counter(Metric::BacktrackAttempts, 1); + log::warn!( + "Making attempt {attempt} to backtrack and retry `{description}`, due to \ + missing digest {digest:?}." + ); + } + } + + // Invalidate the matched roots. + self + .core + .graph + .invalidate_from_roots(move |node| match node { + NodeKey::ExecuteProcess(p) => roots.contains(p), + _ => false, + }); + + // We invalidated a Node, and the caller (at some level above us in the stack) should retry. + // Complete this node with the Invalidated state. + // TODO: Differentiate the reasons for Invalidation (filesystem changes vs missing digests) to + // improve warning messages. See https://github.com/pantsbuild/pants/issues/15867 + Err(Failure::Invalidated) + } + + /// + /// Called before executing a process to determine whether it is backtracking, and if so, to + /// increment the attempt count. + /// + /// A process which has not been marked backtracking will always return 0, regardless of the + /// number of calls to this method. 
+ /// + pub fn maybe_start_backtracking(&self, node: &ExecuteProcess) -> usize { + let mut backtrack_attempts = self.backtrack_attempts.lock(); + let entry: Option<&mut usize> = backtrack_attempts.get_mut(node); + if let Some(entry) = entry { + let attempt = *entry; + *entry += 1; + attempt + } else { + 0 + } + } } impl NodeContext for Context { @@ -700,6 +790,7 @@ impl NodeContext for Context { core: self.core.clone(), session: self.session.clone(), run_id: self.run_id, + backtrack_attempts: self.backtrack_attempts.clone(), stats: self.stats.clone(), } } diff --git a/src/rust/engine/src/externs/testutil.rs b/src/rust/engine/src/externs/testutil.rs index ceedeab5296..19edb1eae55 100644 --- a/src/rust/engine/src/externs/testutil.rs +++ b/src/rust/engine/src/externs/testutil.rs @@ -8,9 +8,11 @@ use pyo3::exceptions::PyAssertionError; use pyo3::prelude::*; use pyo3::types::PyType; -use crate::externs::scheduler::PyExecutor; use testutil_mock::{StubCAS, StubCASBuilder}; +use crate::externs::fs::PyFileDigest; +use crate::externs::scheduler::PyExecutor; + pub fn register(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; @@ -22,12 +24,20 @@ struct PyStubCASBuilder(Arc>>); #[pymethods] impl PyStubCASBuilder { - fn always_errors(&mut self) -> PyResult { + fn ac_always_errors(&mut self) -> PyResult { let mut builder_opt = self.0.lock(); let builder = builder_opt .take() .ok_or_else(|| PyAssertionError::new_err("Unable to unwrap StubCASBuilder"))?; - *builder_opt = Some(builder.always_errors()); + *builder_opt = Some(builder.ac_always_errors()); + Ok(PyStubCASBuilder(self.0.clone())) + } + fn cas_always_errors(&mut self) -> PyResult { + let mut builder_opt = self.0.lock(); + let builder = builder_opt + .take() + .ok_or_else(|| PyAssertionError::new_err("Unable to unwrap StubCASBuilder"))?; + *builder_opt = Some(builder.cas_always_errors()); Ok(PyStubCASBuilder(self.0.clone())) } @@ -56,4 +66,12 @@ impl PyStubCAS { fn address(&self) -> String { 
self.0.address() } + + fn remove(&self, digest: PyFileDigest) -> bool { + self.0.remove(digest.0.hash) + } + + fn action_cache_len(&self) -> usize { + self.0.action_cache.len() + } } diff --git a/src/rust/engine/src/intrinsics.rs b/src/rust/engine/src/intrinsics.rs index d8c690526c6..3f7ee6514a5 100644 --- a/src/rust/engine/src/intrinsics.rs +++ b/src/rust/engine/src/intrinsics.rs @@ -176,7 +176,7 @@ fn process_request_to_process_result( ) -> BoxFuture<'static, NodeResult> { async move { let process_request = ExecuteProcess::lift(&context.core.store(), args.pop().unwrap()) - .map_err(|e| throw(format!("Error lifting Process: {}", e))) + .map_err(|e| e.enrich("Error lifting Process")) .await?; let result = context.get(process_request).await?.0; diff --git a/src/rust/engine/src/nodes.rs b/src/rust/engine/src/nodes.rs index cd4c1345b83..b142cced741 100644 --- a/src/rust/engine/src/nodes.rs +++ b/src/rust/engine/src/nodes.rs @@ -40,7 +40,7 @@ use crate::externs::engine_aware::{EngineAwareParameter, EngineAwareReturnType}; use crate::externs::fs::PyFileDigest; use graph::{Entry, Node, NodeError, NodeVisualizer}; use hashing::Digest; -use store::{self, Store, StoreFileByDigest}; +use store::{self, Store, StoreError, StoreFileByDigest}; use workunit_store::{ in_workunit, Level, Metric, ObservationMetric, RunningWorkunit, UserMetadataItem, WorkunitMetadata, @@ -261,14 +261,14 @@ pub fn lift_file_digest(digest: &PyAny) -> Result { /// #[derive(Clone, Debug, DeepSizeOf, Eq, Hash, PartialEq)] pub struct ExecuteProcess { - process: Process, + pub process: Process, } impl ExecuteProcess { async fn lift_process_input_digests( store: &Store, value: &Value, - ) -> Result { + ) -> Result { let input_digests_fut: Result<_, String> = Python::with_gil(|py| { let value = (**value).as_ref(py); let input_files = lift_directory_digest(externs::getattr(value, "input_digest").unwrap()) @@ -294,10 +294,10 @@ impl ExecuteProcess { input_digests_fut? 
.await - .map_err(|e| format!("Failed to merge input digests for process: {}", e)) + .map_err(|e| e.enrich("Failed to merge input digests for process")) } - fn lift_process(value: &PyAny, input_digests: InputDigests) -> Result { + pub fn lift_process(value: &PyAny, input_digests: InputDigests) -> Result { let env = externs::getattr_from_str_frozendict(value, "env"); let working_directory = match externs::getattr_as_optional_string(value, "working_directory") { None => None, @@ -374,7 +374,7 @@ impl ExecuteProcess { }) } - pub async fn lift(store: &Store, value: Value) -> Result { + pub async fn lift(store: &Store, value: Value) -> Result { let input_digests = Self::lift_process_input_digests(store, &value).await?; let process = Python::with_gil(|py| Self::lift_process((*value).as_ref(py), input_digests))?; Ok(Self { process }) @@ -384,10 +384,18 @@ impl ExecuteProcess { self, context: Context, workunit: &mut RunningWorkunit, + attempt: usize, ) -> NodeResult { let request = self.process; - let command_runner = &context.core.command_runners[0]; + let command_runner = context.core.command_runners.get(attempt).ok_or_else(|| { + // NB: We only backtrack for a Process if it produces a Digest which cannot be consumed + // from disk: if we've fallen all the way back to local execution, and even that + // produces an unreadable Digest, then there is a fundamental implementation issue. + throw(format!( + "Process {request:?} produced an invalid result on all configured command runners." 
+ )) + })?; let execution_context = process_execution::Context::new( context.session.workunit_store(), @@ -1296,6 +1304,18 @@ impl NodeKey { } } + async fn maybe_watch(&self, context: &Context) -> NodeResult<()> { + if let Some((path, watcher)) = self.fs_subject().zip(context.core.watcher.as_ref()) { + let abs_path = context.core.build_root.join(path); + watcher + .watch(abs_path) + .map_err(|e| Context::mk_error(&e)) + .await + } else { + Ok(()) + } + } + /// /// Filters the given Params to those which are subtypes of EngineAwareParameter. /// @@ -1347,28 +1367,19 @@ impl Node for NodeKey { .collect() }, |workunit| async move { - // To avoid races, we must ensure that we have installed a watch for the subject before - // executing the node logic. But in case of failure, we wait to see if the Node itself - // fails, and prefer that error message if so (because we have little control over the - // error messages of the watch API). - let maybe_watch = - if let Some((path, watcher)) = self.fs_subject().zip(context.core.watcher.as_ref()) { - let abs_path = context.core.build_root.join(path); - watcher - .watch(abs_path) - .map_err(|e| Context::mk_error(&e)) - .await - } else { - Ok(()) - }; + // Ensure that we have installed filesystem watches before Nodes which inspect the + // filesystem. 
+ let maybe_watch = self.maybe_watch(&context).await; let mut result = match self { NodeKey::DigestFile(n) => n.run_node(context).await.map(NodeOutput::FileDigest), NodeKey::DownloadedFile(n) => n.run_node(context).await.map(NodeOutput::Snapshot), - NodeKey::ExecuteProcess(n) => n - .run_node(context, workunit) - .await - .map(|r| NodeOutput::ProcessResult(Box::new(r))), + NodeKey::ExecuteProcess(n) => { + let attempt = context.maybe_start_backtracking(&n); + n.run_node(context, workunit, attempt) + .await + .map(|r| NodeOutput::ProcessResult(Box::new(r))) + } NodeKey::ReadLink(n) => n.run_node(context).await.map(NodeOutput::LinkDest), NodeKey::Scandir(n) => n.run_node(context).await.map(NodeOutput::DirectoryListing), NodeKey::Select(n) => n.run_node(context).await.map(NodeOutput::Value), @@ -1379,7 +1390,11 @@ impl Node for NodeKey { NodeKey::Task(n) => n.run_node(context, workunit).await.map(NodeOutput::Value), }; - // If both the Node and the watch failed, prefer the Node's error message. + // If the Node failed with MissingDigest, attempt to invalidate the source of the Digest. + result = context2.maybe_backtrack(result, workunit); + + // If both the Node and the watch failed, prefer the Node's error message (we have little + // control over the error messages of the watch API). 
match (&result, maybe_watch) { (Ok(_), Ok(_)) => {} (Err(_), _) => {} diff --git a/src/rust/engine/testutil/Cargo.toml b/src/rust/engine/testutil/Cargo.toml index 4d854422a4a..c7261b67c68 100644 --- a/src/rust/engine/testutil/Cargo.toml +++ b/src/rust/engine/testutil/Cargo.toml @@ -7,9 +7,9 @@ publish = false [dependencies] async-stream = "0.3" -protos = { path = "../protos" } bytes = "1.0" -grpc_util = { path = "../grpc_util" } fs = { path = "../fs" } +grpc_util = { path = "../grpc_util" } hashing = { path = "../hashing" } prost = "0.9" +protos = { path = "../protos" } diff --git a/src/rust/engine/testutil/mock/Cargo.toml b/src/rust/engine/testutil/mock/Cargo.toml index 184adf7ec5d..d770d33f527 100644 --- a/src/rust/engine/testutil/mock/Cargo.toml +++ b/src/rust/engine/testutil/mock/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] async-stream = "0.3" -protos = { path = "../../protos" } bytes = "1.0" futures = "0.3" grpc_util = { path = "../../grpc_util" } @@ -17,6 +16,7 @@ log = "0.4" parking_lot = "0.12" prost = "0.9" prost-types = "0.9" +protos = { path = "../../protos" } testutil = { path = ".." } tokio = { version = "1.16", features = ["time"] } tonic = { version = "0.6" } diff --git a/src/rust/engine/testutil/mock/src/action_cache.rs b/src/rust/engine/testutil/mock/src/action_cache_service.rs similarity index 54% rename from src/rust/engine/testutil/mock/src/action_cache.rs rename to src/rust/engine/testutil/mock/src/action_cache_service.rs index 3bd92ea9720..9b5a74d500c 100644 --- a/src/rust/engine/testutil/mock/src/action_cache.rs +++ b/src/rust/engine/testutil/mock/src/action_cache_service.rs @@ -2,42 +2,69 @@ // Licensed under the Apache License, Version 2.0 (see LICENSE). 
use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; -use futures::FutureExt; -use grpc_util::hyper::AddrIncomingWithStream; -use hashing::{Digest, Fingerprint}; use parking_lot::Mutex; -use protos::gen::build::bazel::remote::execution::v2 as remexec; -use protos::require_digest; -use remexec::action_cache_server::{ActionCache, ActionCacheServer}; +use remexec::action_cache_server::ActionCache; use remexec::{ActionResult, GetActionResultRequest, UpdateActionResultRequest}; use tokio::time::sleep; -use tonic::transport::Server; use tonic::{Request, Response, Status}; -pub struct StubActionCache { +use hashing::{Digest, Fingerprint}; +use protos::gen::build::bazel::remote::execution::v2 as remexec; +use protos::require_digest; + +pub struct ActionCacheHandle { pub action_map: Arc>>, pub always_errors: Arc, - local_addr: SocketAddr, - shutdown_sender: Option>, } -impl Drop for StubActionCache { - fn drop(&mut self) { - self.shutdown_sender.take().unwrap().send(()).unwrap(); +impl ActionCacheHandle { + /// + /// Inserts the given action digest into the cache with the given outputs. + /// + pub fn insert( + &self, + action_digest: Digest, + exit_code: i32, + stdout_digest: Digest, + stderr_digest: Digest, + ) { + let action_result = ActionResult { + exit_code, + stdout_digest: Some(stdout_digest.into()), + stderr_digest: Some(stderr_digest.into()), + ..ActionResult::default() + }; + self + .action_map + .lock() + .insert(action_digest.hash, action_result); + } + + /// + /// Get the result for the given action digest. + /// + pub fn get(&self, action_digest: Digest) -> Option { + self.action_map.lock().get(&action_digest.hash).cloned() + } + + /// + /// Returns the number of cache entries in the cache. 
+ /// + pub fn len(&self) -> usize { + self.action_map.lock().len() } } #[derive(Clone)] -struct ActionCacheResponder { - action_map: Arc>>, - always_errors: Arc, - read_delay: Duration, - write_delay: Duration, +pub(crate) struct ActionCacheResponder { + pub action_map: Arc>>, + pub always_errors: Arc, + pub read_delay: Duration, + pub write_delay: Duration, } #[tonic::async_trait] @@ -109,54 +136,3 @@ impl ActionCache for ActionCacheResponder { Ok(Response::new(action_result)) } } - -impl StubActionCache { - pub fn new() -> Result { - Self::new_with_delays(0, 0) - } - - pub fn new_with_delays(read_delay_ms: u64, write_delay_ms: u64) -> Result { - let action_map = Arc::new(Mutex::new(HashMap::new())); - let always_errors = Arc::new(AtomicBool::new(false)); - let responder = ActionCacheResponder { - action_map: action_map.clone(), - always_errors: always_errors.clone(), - read_delay: Duration::from_millis(read_delay_ms), - write_delay: Duration::from_millis(write_delay_ms), - }; - - let addr = "127.0.0.1:0" - .to_string() - .parse() - .expect("failed to parse IP address"); - let incoming = hyper::server::conn::AddrIncoming::bind(&addr).expect("failed to bind port"); - let local_addr = incoming.local_addr(); - let incoming = AddrIncomingWithStream(incoming); - - let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel(); - - tokio::spawn(async move { - let mut server = Server::builder(); - let router = server.add_service(ActionCacheServer::new(responder.clone())); - - router - .serve_with_incoming_shutdown(incoming, shutdown_receiver.map(drop)) - .await - .unwrap(); - }); - - Ok(StubActionCache { - action_map, - always_errors, - local_addr, - shutdown_sender: Some(shutdown_sender), - }) - } - - /// - /// The address on which this server is listening over insecure HTTP transport. 
- /// - pub fn address(&self) -> String { - format!("http://{}", self.local_addr) - } -} diff --git a/src/rust/engine/testutil/mock/src/cas.rs b/src/rust/engine/testutil/mock/src/cas.rs index e65e55318cf..53e5a370fde 100644 --- a/src/rust/engine/testutil/mock/src/cas.rs +++ b/src/rust/engine/testutil/mock/src/cas.rs @@ -1,43 +1,45 @@ use std::collections::HashMap; -use std::convert::TryInto; use std::net::SocketAddr; -use std::pin::Pin; +use std::sync::atomic::AtomicBool; use std::sync::Arc; +use std::time::Duration; -use bytes::{Bytes, BytesMut}; -use futures::stream::StreamExt; -use futures::{FutureExt, Stream}; +use bytes::Bytes; +use futures::FutureExt; use grpc_util::hyper::AddrIncomingWithStream; -use hashing::{Digest, Fingerprint}; +use hashing::Fingerprint; use parking_lot::Mutex; use protos::gen::build::bazel::remote::execution::v2 as remexec; -use protos::gen::build::bazel::semver::SemVer; -use protos::gen::google::bytestream::{ - byte_stream_server::ByteStream, byte_stream_server::ByteStreamServer, QueryWriteStatusRequest, - QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest, WriteResponse, -}; -use remexec::capabilities_server::{Capabilities, CapabilitiesServer}; -use remexec::content_addressable_storage_server::{ - ContentAddressableStorage, ContentAddressableStorageServer, -}; -use remexec::{ - BatchReadBlobsRequest, BatchReadBlobsResponse, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse, - CacheCapabilities, ExecutionCapabilities, FindMissingBlobsRequest, FindMissingBlobsResponse, - GetCapabilitiesRequest, GetTreeRequest, GetTreeResponse, ServerCapabilities, -}; +use protos::gen::google::bytestream::byte_stream_server::ByteStreamServer; +use remexec::action_cache_server::ActionCacheServer; +use remexec::capabilities_server::CapabilitiesServer; +use remexec::content_addressable_storage_server::ContentAddressableStorageServer; use testutil::data::{TestData, TestDirectory, TestTree}; -use tonic::metadata::{AsciiMetadataKey, 
KeyAndValueRef}; use tonic::transport::Server; -use tonic::{Request, Response, Status}; +use crate::action_cache_service::{ActionCacheHandle, ActionCacheResponder}; +use crate::cas_service::StubCASResponder; + +/// +/// Implements the: +/// * ContentAddressableStorage +/// * ActionCache +/// * Capabilities +/// ...gRPC APIs. /// -/// Implements the ContentAddressableStorage gRPC API, answering read requests with either known -/// content, NotFound for valid but unknown content, or InvalidArguments for bad arguments. +/// NB: You might expect that these services could be generically composed, but the +/// `tonic::{Server, Router}` builder pattern changes its type with each call to `add_service`, +/// making it very challenging to wrap. Instead, we statically compose them. /// pub struct StubCAS { + // CAS fields. + // TODO: These are inlined (rather than namespaced) for backwards compatibility. read_request_count: Arc>, pub write_message_sizes: Arc>>, pub blobs: Arc>>, + // AC fields. + pub action_cache: ActionCacheHandle, + // Generic server fields. 
local_addr: SocketAddr, shutdown_sender: Option>, } @@ -51,23 +53,29 @@ impl Drop for StubCAS { } pub struct StubCASBuilder { - always_errors: bool, + ac_always_errors: bool, + cas_always_errors: bool, chunk_size_bytes: Option, content: HashMap, port: Option, instance_name: Option, required_auth_token: Option, + ac_read_delay: Duration, + ac_write_delay: Duration, } impl StubCASBuilder { pub fn new() -> Self { StubCASBuilder { - always_errors: false, + ac_always_errors: false, + cas_always_errors: false, chunk_size_bytes: None, content: HashMap::new(), port: None, instance_name: None, required_auth_token: None, + ac_read_delay: Duration::from_millis(0), + ac_write_delay: Duration::from_millis(0), } } } @@ -111,8 +119,23 @@ impl StubCASBuilder { self } - pub fn always_errors(mut self) -> Self { - self.always_errors = true; + pub fn ac_always_errors(mut self) -> Self { + self.ac_always_errors = true; + self + } + + pub fn cas_always_errors(mut self) -> Self { + self.cas_always_errors = true; + self + } + + pub fn ac_read_delay(mut self, duration: Duration) -> Self { + self.ac_read_delay = duration; + self + } + + pub fn ac_write_delay(mut self, duration: Duration) -> Self { + self.ac_write_delay = duration; self } @@ -133,53 +156,29 @@ impl StubCASBuilder { } pub fn build(self) -> StubCAS { - StubCAS::new( - self.chunk_size_bytes.unwrap_or(1024), - self.content, - self.port.unwrap_or(0), - self.always_errors, - self.instance_name, - self.required_auth_token, - ) - } -} - -impl StubCAS { - pub fn builder() -> StubCASBuilder { - StubCASBuilder::new() - } - - /// - /// # Arguments - /// * `chunk_size_bytes` - The maximum number of bytes of content to include per streamed message. - /// Messages will saturate until the last one, which may be smaller than - /// this value. - /// If a negative value is given, all requests will receive an error. - /// * `blobs` - Known Fingerprints and their content responses. These are not checked - /// for correctness. 
- /// * `port` - The port for the CAS to listen to. - fn new( - chunk_size_bytes: usize, - blobs: HashMap, - port: u16, - always_errors: bool, - instance_name: Option, - required_auth_token: Option, - ) -> StubCAS { let read_request_count = Arc::new(Mutex::new(0)); let write_message_sizes = Arc::new(Mutex::new(Vec::new())); - let blobs = Arc::new(Mutex::new(blobs)); - let responder = StubCASResponder { - chunk_size_bytes, - instance_name, + let blobs = Arc::new(Mutex::new(self.content)); + let cas_responder = StubCASResponder { + chunk_size_bytes: self.chunk_size_bytes.unwrap_or(1024), + instance_name: self.instance_name, blobs: blobs.clone(), - always_errors, + always_errors: self.cas_always_errors, read_request_count: read_request_count.clone(), write_message_sizes: write_message_sizes.clone(), - required_auth_header: required_auth_token.map(|t| format!("Bearer {}", t)), + required_auth_header: self.required_auth_token.map(|t| format!("Bearer {}", t)), + }; + + let action_map = Arc::new(Mutex::new(HashMap::new())); + let ac_always_errors = Arc::new(AtomicBool::new(self.ac_always_errors)); + let ac_responder = ActionCacheResponder { + action_map: action_map.clone(), + always_errors: ac_always_errors.clone(), + read_delay: self.ac_read_delay, + write_delay: self.ac_write_delay, }; - let addr = format!("127.0.0.1:{}", port) + let addr = format!("127.0.0.1:{}", self.port.unwrap_or(0)) .parse() .expect("failed to parse IP address"); let incoming = hyper::server::conn::AddrIncoming::bind(&addr).expect("failed to bind port"); @@ -191,9 +190,10 @@ impl StubCAS { tokio::spawn(async move { let mut server = Server::builder(); let router = server - .add_service(ByteStreamServer::new(responder.clone())) - .add_service(ContentAddressableStorageServer::new(responder.clone())) - .add_service(CapabilitiesServer::new(responder)); + .add_service(ActionCacheServer::new(ac_responder.clone())) + .add_service(ByteStreamServer::new(cas_responder.clone())) + 
.add_service(ContentAddressableStorageServer::new(cas_responder.clone())) + .add_service(CapabilitiesServer::new(cas_responder)); router .serve_with_incoming_shutdown(incoming, shutdown_receiver.map(drop)) @@ -205,17 +205,27 @@ impl StubCAS { read_request_count, write_message_sizes, blobs, + action_cache: ActionCacheHandle { + action_map, + always_errors: ac_always_errors, + }, local_addr, shutdown_sender: Some(shutdown_sender), } } +} + +impl StubCAS { + pub fn builder() -> StubCASBuilder { + StubCASBuilder::new() + } pub fn empty() -> StubCAS { StubCAS::builder().build() } - pub fn always_errors() -> StubCAS { - StubCAS::builder().always_errors().build() + pub fn cas_always_errors() -> StubCAS { + StubCAS::builder().cas_always_errors().build() } /// @@ -228,664 +238,8 @@ impl StubCAS { pub fn read_request_count(&self) -> usize { *self.read_request_count.lock() } -} - -#[derive(Clone, Debug)] -pub struct StubCASResponder { - chunk_size_bytes: usize, - instance_name: Option, - blobs: Arc>>, - always_errors: bool, - required_auth_header: Option, - pub read_request_count: Arc>, - pub write_message_sizes: Arc>>, -} - -macro_rules! check_auth { - ($self:ident, $req:ident) => { - if let Some(ref required_auth_header) = $self.required_auth_header { - let auth_header = AsciiMetadataKey::from_static("authorization"); - let authorization_headers: Vec<_> = $req - .metadata() - .iter() - .filter_map(|kv| match kv { - KeyAndValueRef::Ascii(key, value) if key == auth_header => Some((key, value)), - _ => None, - }) - .map(|(_key, value)| value) - .collect(); - if authorization_headers.len() != 1 - || authorization_headers[0] != required_auth_header.as_bytes() - { - return Err(Status::unauthenticated(format!( - "Bad Authorization header; want {:?} got {:?}", - required_auth_header.as_bytes(), - authorization_headers - ))); - } - } - }; -} - -macro_rules! 
check_instance_name { - ($self:ident, $req:ident) => { - if $req.instance_name != $self.instance_name() { - return Err(Status::not_found(format!( - "Instance {} does not exist", - $req.instance_name - ))); - } - }; -} - -#[derive(Debug, Eq, PartialEq)] -struct ParsedWriteResourceName<'a> { - instance_name: &'a str, - _uuid: &'a str, - hash: &'a str, - size: usize, -} - -/// Parses a resource name of the form `{instance_name}/uploads/{uuid}/blobs/{hash}/{size}` into -/// a struct with references to the individual components of the resource name. The -/// `{instance_name}` may be blank (with no leading slash) as per REAPI specification. -fn parse_write_resource_name(resource: &str) -> Result { - if resource.is_empty() { - return Err("Missing resource name".to_owned()); - } - - // Parse the resource name into parts separated by slashes (/). - let parts: Vec<_> = resource.split('/').collect(); - - // Search for the `uploads` path component. - let uploads_index = match parts.iter().position(|p| *p == "uploads") { - Some(index) => index, - None => return Err("Malformed resource name: missing `uploads` component".to_owned()), - }; - let instance_parts = &parts[0..uploads_index]; - - if (parts.len() - uploads_index) < 5 { - return Err("Malformed resource name: not enough path components after `uploads`".to_owned()); - } - - if parts[uploads_index + 2] != "blobs" { - return Err("Malformed resource name: expected `blobs` component".to_owned()); - } - - let size = parts[uploads_index + 4] - .parse::() - .map_err(|_| "Malformed resource name: cannot parse size".to_owned())?; - - let instance_name = if instance_parts.is_empty() { - "" - } else { - let last_instance_name_index = - instance_parts.iter().map(|x| (*x).len()).sum::() + instance_parts.len() - 1; - &resource[0..last_instance_name_index] - }; - - Ok(ParsedWriteResourceName { - instance_name, - _uuid: parts[uploads_index + 1], - hash: parts[uploads_index + 3], - size, - }) -} - -#[derive(Debug, Eq, PartialEq)] -struct 
ParsedReadResourceName<'a> { - instance_name: &'a str, - hash: &'a str, - size: usize, -} - -/// `"{instance_name}/blobs/{hash}/{size}"` -fn parse_read_resource_name(resource: &str) -> Result { - if resource.is_empty() { - return Err("Missing resource name".to_owned()); - } - - // Parse the resource name into parts separated by slashes (/). - let parts: Vec<_> = resource.split('/').collect(); - - // Search for the `blobs` path component. - let blobs_index = match parts.iter().position(|p| *p == "blobs") { - Some(index) => index, - None => return Err("Malformed resource name: missing `blobs` component".to_owned()), - }; - let instance_parts = &parts[0..blobs_index]; - - if (parts.len() - blobs_index) < 3 { - return Err("Malformed resource name: not enough path components after `blobs`".to_owned()); - } - - let size = parts[blobs_index + 2] - .parse::() - .map_err(|_| "Malformed resource name: cannot parse size".to_owned())?; - - let instance_name = if instance_parts.is_empty() { - "" - } else { - let last_instance_name_index = - instance_parts.iter().map(|x| (*x).len()).sum::() + instance_parts.len() - 1; - &resource[0..last_instance_name_index] - }; - - Ok(ParsedReadResourceName { - instance_name, - hash: parts[blobs_index + 1], - size, - }) -} - -impl StubCASResponder { - fn instance_name(&self) -> String { - self.instance_name.clone().unwrap_or_default() - } - - fn read_internal(&self, req: &ReadRequest) -> Result, Status> { - let parsed_resource_name = parse_read_resource_name(&req.resource_name) - .map_err(|err| Status::invalid_argument(format!("Failed to parse resource name: {}", err)))?; - - let digest = parsed_resource_name.hash; - let fingerprint = Fingerprint::from_hex_string(digest) - .map_err(|e| Status::invalid_argument(format!("Bad digest {}: {}", digest, e)))?; - if self.always_errors { - return Err(Status::internal( - "StubCAS is configured to always fail".to_owned(), - )); - } - let blobs = self.blobs.lock(); - let maybe_bytes = 
blobs.get(&fingerprint); - match maybe_bytes { - Some(bytes) => Ok( - bytes - .chunks(self.chunk_size_bytes as usize) - .map(|b| ReadResponse { - data: bytes.slice_ref(b), - }) - .collect(), - ), - None => Err(Status::not_found(format!( - "Did not find digest {}", - fingerprint - ))), - } - } -} - -#[tonic::async_trait] -impl ByteStream for StubCASResponder { - type ReadStream = Pin> + Send + Sync>>; - - async fn read( - &self, - request: Request, - ) -> Result, Status> { - { - let mut request_count = self.read_request_count.lock(); - *request_count += 1; - } - check_auth!(self, request); - - let request = request.into_inner(); - - let stream_elements = self.read_internal(&request)?; - let stream = Box::pin(futures::stream::iter( - stream_elements.into_iter().map(Ok).collect::>(), - )); - Ok(Response::new(stream)) - } - - async fn write( - &self, - request: Request>, - ) -> Result, Status> { - check_auth!(self, request); - - let always_errors = self.always_errors; - let write_message_sizes = self.write_message_sizes.clone(); - let blobs = self.blobs.clone(); - - let mut stream = request.into_inner(); - - let mut maybe_resource_name = None; - let mut want_next_offset = 0; - let mut bytes = BytesMut::new(); - - while let Some(req_result) = stream.next().await { - let req = match req_result { - Ok(r) => r, - Err(e) => { - return Err(Status::invalid_argument(format!( - "Client sent an error: {}", - e - ))) - } - }; - - match maybe_resource_name { - None => maybe_resource_name = Some(req.resource_name.clone()), - Some(ref resource_name) => { - if *resource_name != req.resource_name { - return Err(Status::invalid_argument(format!( - "All resource names in stream must be the same. Got {} but earlier saw {}", - req.resource_name, resource_name - ))); - } - } - } - - if req.write_offset != want_next_offset { - return Err(Status::invalid_argument(format!( - "Missing chunk. 
Expected next offset {}, got next offset: {}", - want_next_offset, req.write_offset - ))); - } - - want_next_offset += req.data.len() as i64; - write_message_sizes.lock().push(req.data.len()); - bytes.extend_from_slice(&req.data); - } - - let bytes = bytes.freeze(); - - match maybe_resource_name { - None => Err(Status::invalid_argument( - "Stream saw no messages".to_owned(), - )), - Some(resource_name) => { - let parsed_resource_name = - parse_write_resource_name(&resource_name).map_err(Status::internal)?; - - if parsed_resource_name.instance_name != self.instance_name().as_str() { - return Err(Status::invalid_argument(format!( - "Bad instance name in resource name: expected={}, actual={}", - self.instance_name(), - parsed_resource_name.instance_name - ))); - } - - let fingerprint = match Fingerprint::from_hex_string(parsed_resource_name.hash) { - Ok(f) => f, - Err(err) => { - return Err(Status::invalid_argument(format!( - "Bad fingerprint in resource name: {}: {}", - parsed_resource_name.hash, err - ))); - } - }; - let size = parsed_resource_name.size; - if size != bytes.len() { - return Err(Status::invalid_argument(format!( - "Size was incorrect: resource name said size={} but got {}", - size, - bytes.len() - ))); - } - - if always_errors { - return Err(Status::invalid_argument( - "StubCAS is configured to always fail".to_owned(), - )); - } - - { - let mut blobs = blobs.lock(); - blobs.insert(fingerprint, bytes); - } - - let response = WriteResponse { - committed_size: size as i64, - }; - Ok(Response::new(response)) - } - } - } - - async fn query_write_status( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("".to_owned())) - } -} - -#[tonic::async_trait] -impl ContentAddressableStorage for StubCASResponder { - async fn find_missing_blobs( - &self, - request: Request, - ) -> Result, Status> { - check_auth!(self, request); - - if self.always_errors { - return Err(Status::internal( - "StubCAS is configured to always fail".to_owned(), - 
)); - } - - let request = request.into_inner(); - - check_instance_name!(self, request); - - let blobs = self.blobs.lock(); - let mut response = FindMissingBlobsResponse::default(); - for digest in request.blob_digests { - let hashing_digest_result: Result = digest.try_into(); - let hashing_digest = hashing_digest_result.expect("Bad digest"); - if !blobs.contains_key(&hashing_digest.hash) { - response.missing_blob_digests.push(hashing_digest.into()) - } - } - Ok(Response::new(response)) - } - - async fn batch_update_blobs( - &self, - request: Request, - ) -> Result, Status> { - check_auth!(self, request); - - if self.always_errors { - return Err(Status::invalid_argument( - "StubCAS is configured to always fail".to_owned(), - )); - } - - let request = request.into_inner(); - - check_instance_name!(self, request); - - let mut responses = Vec::new(); - let mut blobs = self.blobs.lock(); - - fn write_blob( - request: remexec::batch_update_blobs_request::Request, - blobs: &mut HashMap, - ) -> Status { - let digest = match request.digest { - Some(d) => d, - None => return Status::invalid_argument("digest not set in batch update request"), - }; - - let fingerprint = match Fingerprint::from_hex_string(&digest.hash) { - Ok(f) => f, - Err(err) => { - return Status::invalid_argument(format!("Bad fingerprint: {}: {}", &digest.hash, err)); - } - }; - - if request.data.len() != digest.size_bytes as usize { - return Status::invalid_argument(format!( - "Size was incorrect: digest size is {} but got {} from data", - digest.size_bytes, - request.data.len() - )); - } - - blobs.insert(fingerprint, request.data); - Status::ok("") - } - - for blob_request in request.requests { - let digest = blob_request.digest.clone(); - self - .write_message_sizes - .lock() - .push(blob_request.data.len()); - let status = write_blob(blob_request, &mut blobs); - responses.push(remexec::batch_update_blobs_response::Response { - digest, - status: Some(protos::gen::google::rpc::Status { - code: 
status.code() as i32, - message: status.message().to_string(), - ..protos::gen::google::rpc::Status::default() - }), - }) - } - - Ok(Response::new(BatchUpdateBlobsResponse { responses })) - } - - async fn batch_read_blobs( - &self, - request: Request, - ) -> Result, Status> { - check_auth!(self, request); - - if self.always_errors { - return Err(Status::invalid_argument( - "StubCAS is configured to always fail".to_owned(), - )); - } - - let request = request.into_inner(); - - check_instance_name!(self, request); - - let mut responses = Vec::new(); - let blobs = self.blobs.lock(); - - fn read_blob( - digest: remexec::Digest, - blobs: &HashMap, - ) -> (Option, Status) { - let fingerprint = match Fingerprint::from_hex_string(&digest.hash) { - Ok(f) => f, - Err(err) => { - return ( - None, - Status::invalid_argument(format!("Bad fingerprint: {}: {}", &digest.hash, err)), - ); - } - }; - - match blobs.get(&fingerprint) { - Some(data) => { - if data.len() == digest.size_bytes as usize { - (Some(data.clone()), Status::ok("")) - } else { - ( - None, - Status::invalid_argument(format!( - "Size was incorrect: digest size is {} but got {} from data", - digest.size_bytes, - data.len() - )), - ) - } - } - None => (None, Status::not_found("")), - } - } - - for digest in request.digests { - let (data_opt, status) = read_blob(digest.clone(), &blobs); - responses.push(remexec::batch_read_blobs_response::Response { - digest: Some(digest), - data: data_opt.unwrap_or_else(Bytes::new), - status: Some(protos::gen::google::rpc::Status { - code: status.code() as i32, - message: status.message().to_string(), - ..protos::gen::google::rpc::Status::default() - }), - }); - } - - Ok(Response::new(remexec::BatchReadBlobsResponse { responses })) - } - - type GetTreeStream = tonic::codec::Streaming; - - async fn get_tree( - &self, - _: Request, - ) -> Result, Status> { - Err(Status::unimplemented("".to_owned())) - } -} - -#[tonic::async_trait] -impl Capabilities for StubCASResponder { - async fn 
get_capabilities( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - check_instance_name!(self, request); - - let response = ServerCapabilities { - cache_capabilities: Some(CacheCapabilities { - digest_function: vec![remexec::digest_function::Value::Sha256 as i32], - max_batch_total_size_bytes: 0, - ..CacheCapabilities::default() - }), - execution_capabilities: Some(ExecutionCapabilities { - digest_function: remexec::digest_function::Value::Sha256 as i32, - exec_enabled: true, - ..ExecutionCapabilities::default() - }), - high_api_version: Some(SemVer { - major: 2, - minor: 999, - ..SemVer::default() - }), - ..ServerCapabilities::default() - }; - - Ok(Response::new(response)) - } -} - -#[cfg(test)] -mod tests { - use super::{ - parse_read_resource_name, parse_write_resource_name, ParsedReadResourceName, - ParsedWriteResourceName, - }; - - #[test] - fn parse_write_resource_name_correctly() { - let result = parse_write_resource_name("main/uploads/uuid-12345/blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedWriteResourceName { - instance_name: "main", - _uuid: "uuid-12345", - hash: "abc123", - size: 12, - } - ); - - let result = parse_write_resource_name("uploads/uuid-12345/blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedWriteResourceName { - instance_name: "", - _uuid: "uuid-12345", - hash: "abc123", - size: 12, - } - ); - - let result = parse_write_resource_name("a/b/c/uploads/uuid-12345/blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedWriteResourceName { - instance_name: "a/b/c", - _uuid: "uuid-12345", - hash: "abc123", - size: 12, - } - ); - - // extra components after the size are accepted - let result = - parse_write_resource_name("a/b/c/uploads/uuid-12345/blobs/abc123/12/extra/stuff").unwrap(); - assert_eq!( - result, - ParsedWriteResourceName { - instance_name: "a/b/c", - _uuid: "uuid-12345", - hash: "abc123", - size: 12, - } - ); - } - - #[test] - fn 
parse_write_resource_name_errors_as_expected() { - // - let err = parse_write_resource_name("").unwrap_err(); - assert_eq!(err, "Missing resource name"); - - let err = parse_write_resource_name("main/uuid-12345/blobs/abc123/12").unwrap_err(); - assert_eq!(err, "Malformed resource name: missing `uploads` component"); - - let err = parse_write_resource_name("main/uploads/uuid-12345/abc123/12").unwrap_err(); - assert_eq!( - err, - "Malformed resource name: not enough path components after `uploads`" - ); - - let err = parse_write_resource_name("main/uploads/uuid-12345/abc123/12/foo").unwrap_err(); - assert_eq!(err, "Malformed resource name: expected `blobs` component"); - - // negative size should be rejected - let err = parse_write_resource_name("main/uploads/uuid-12345/blobs/abc123/-12").unwrap_err(); - assert_eq!(err, "Malformed resource name: cannot parse size"); - } - - #[test] - fn parse_read_resource_name_correctly() { - let result = parse_read_resource_name("main/blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedReadResourceName { - instance_name: "main", - hash: "abc123", - size: 12, - } - ); - - let result = parse_read_resource_name("blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedReadResourceName { - instance_name: "", - hash: "abc123", - size: 12, - } - ); - - let result = parse_read_resource_name("a/b/c/blobs/abc123/12").unwrap(); - assert_eq!( - result, - ParsedReadResourceName { - instance_name: "a/b/c", - hash: "abc123", - size: 12, - } - ); - } - - #[test] - fn parse_read_resource_name_errors_as_expected() { - let err = parse_read_resource_name("").unwrap_err(); - assert_eq!(err, "Missing resource name"); - - let err = parse_read_resource_name("main/abc123/12").unwrap_err(); - assert_eq!(err, "Malformed resource name: missing `blobs` component"); - - let err = parse_read_resource_name("main/blobs/12").unwrap_err(); - assert_eq!( - err, - "Malformed resource name: not enough path components after `blobs`" - ); - // negative size 
should be rejected - let err = parse_read_resource_name("main/blobs/abc123/-12").unwrap_err(); - assert_eq!(err, "Malformed resource name: cannot parse size"); + pub fn remove(&self, fingerprint: Fingerprint) -> bool { + self.blobs.lock().remove(&fingerprint).is_some() } } diff --git a/src/rust/engine/testutil/mock/src/cas_service.rs b/src/rust/engine/testutil/mock/src/cas_service.rs new file mode 100644 index 00000000000..de855a1ae36 --- /dev/null +++ b/src/rust/engine/testutil/mock/src/cas_service.rs @@ -0,0 +1,685 @@ +use std::collections::HashMap; +use std::convert::TryInto; +use std::pin::Pin; +use std::sync::Arc; + +use bytes::{Bytes, BytesMut}; +use futures::stream::StreamExt; +use futures::Stream; +use hashing::{Digest, Fingerprint}; +use parking_lot::Mutex; +use protos::gen::build::bazel::remote::execution::v2 as remexec; +use protos::gen::build::bazel::semver::SemVer; +use protos::gen::google::bytestream::{ + byte_stream_server::ByteStream, QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, + ReadResponse, WriteRequest, WriteResponse, +}; +use remexec::capabilities_server::Capabilities; +use remexec::content_addressable_storage_server::ContentAddressableStorage; +use remexec::{ + BatchReadBlobsRequest, BatchReadBlobsResponse, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse, + CacheCapabilities, ExecutionCapabilities, FindMissingBlobsRequest, FindMissingBlobsResponse, + GetCapabilitiesRequest, GetTreeRequest, GetTreeResponse, ServerCapabilities, +}; +use tonic::metadata::{AsciiMetadataKey, KeyAndValueRef}; +use tonic::{Request, Response, Status}; + +#[derive(Clone, Debug)] +pub(crate) struct StubCASResponder { + pub chunk_size_bytes: usize, + pub instance_name: Option, + pub blobs: Arc>>, + pub always_errors: bool, + pub required_auth_header: Option, + pub read_request_count: Arc>, + pub write_message_sizes: Arc>>, +} + +macro_rules! 
check_auth { + ($self:ident, $req:ident) => { + if let Some(ref required_auth_header) = $self.required_auth_header { + let auth_header = AsciiMetadataKey::from_static("authorization"); + let authorization_headers: Vec<_> = $req + .metadata() + .iter() + .filter_map(|kv| match kv { + KeyAndValueRef::Ascii(key, value) if key == auth_header => Some((key, value)), + _ => None, + }) + .map(|(_key, value)| value) + .collect(); + if authorization_headers.len() != 1 + || authorization_headers[0] != required_auth_header.as_bytes() + { + return Err(Status::unauthenticated(format!( + "Bad Authorization header; want {:?} got {:?}", + required_auth_header.as_bytes(), + authorization_headers + ))); + } + } + }; +} + +macro_rules! check_instance_name { + ($self:ident, $req:ident) => { + if $req.instance_name != $self.instance_name() { + return Err(Status::not_found(format!( + "Instance {} does not exist", + $req.instance_name + ))); + } + }; +} + +#[derive(Debug, Eq, PartialEq)] +struct ParsedWriteResourceName<'a> { + instance_name: &'a str, + _uuid: &'a str, + hash: &'a str, + size: usize, +} + +/// Parses a resource name of the form `{instance_name}/uploads/{uuid}/blobs/{hash}/{size}` into +/// a struct with references to the individual components of the resource name. The +/// `{instance_name}` may be blank (with no leading slash) as per REAPI specification. +fn parse_write_resource_name(resource: &str) -> Result { + if resource.is_empty() { + return Err("Missing resource name".to_owned()); + } + + // Parse the resource name into parts separated by slashes (/). + let parts: Vec<_> = resource.split('/').collect(); + + // Search for the `uploads` path component. 
+ let uploads_index = match parts.iter().position(|p| *p == "uploads") { + Some(index) => index, + None => return Err("Malformed resource name: missing `uploads` component".to_owned()), + }; + let instance_parts = &parts[0..uploads_index]; + + if (parts.len() - uploads_index) < 5 { + return Err("Malformed resource name: not enough path components after `uploads`".to_owned()); + } + + if parts[uploads_index + 2] != "blobs" { + return Err("Malformed resource name: expected `blobs` component".to_owned()); + } + + let size = parts[uploads_index + 4] + .parse::() + .map_err(|_| "Malformed resource name: cannot parse size".to_owned())?; + + let instance_name = if instance_parts.is_empty() { + "" + } else { + let last_instance_name_index = + instance_parts.iter().map(|x| (*x).len()).sum::() + instance_parts.len() - 1; + &resource[0..last_instance_name_index] + }; + + Ok(ParsedWriteResourceName { + instance_name, + _uuid: parts[uploads_index + 1], + hash: parts[uploads_index + 3], + size, + }) +} + +#[derive(Debug, Eq, PartialEq)] +struct ParsedReadResourceName<'a> { + instance_name: &'a str, + hash: &'a str, + size: usize, +} + +/// `"{instance_name}/blobs/{hash}/{size}"` +fn parse_read_resource_name(resource: &str) -> Result { + if resource.is_empty() { + return Err("Missing resource name".to_owned()); + } + + // Parse the resource name into parts separated by slashes (/). + let parts: Vec<_> = resource.split('/').collect(); + + // Search for the `blobs` path component. 
+ let blobs_index = match parts.iter().position(|p| *p == "blobs") { + Some(index) => index, + None => return Err("Malformed resource name: missing `blobs` component".to_owned()), + }; + let instance_parts = &parts[0..blobs_index]; + + if (parts.len() - blobs_index) < 3 { + return Err("Malformed resource name: not enough path components after `blobs`".to_owned()); + } + + let size = parts[blobs_index + 2] + .parse::() + .map_err(|_| "Malformed resource name: cannot parse size".to_owned())?; + + let instance_name = if instance_parts.is_empty() { + "" + } else { + let last_instance_name_index = + instance_parts.iter().map(|x| (*x).len()).sum::() + instance_parts.len() - 1; + &resource[0..last_instance_name_index] + }; + + Ok(ParsedReadResourceName { + instance_name, + hash: parts[blobs_index + 1], + size, + }) +} + +impl StubCASResponder { + fn instance_name(&self) -> String { + self.instance_name.clone().unwrap_or_default() + } + + fn read_internal(&self, req: &ReadRequest) -> Result, Status> { + let parsed_resource_name = parse_read_resource_name(&req.resource_name) + .map_err(|err| Status::invalid_argument(format!("Failed to parse resource name: {}", err)))?; + + let digest = parsed_resource_name.hash; + let fingerprint = Fingerprint::from_hex_string(digest) + .map_err(|e| Status::invalid_argument(format!("Bad digest {}: {}", digest, e)))?; + if self.always_errors { + return Err(Status::internal( + "StubCAS is configured to always fail".to_owned(), + )); + } + let blobs = self.blobs.lock(); + let maybe_bytes = blobs.get(&fingerprint); + match maybe_bytes { + Some(bytes) => Ok( + bytes + .chunks(self.chunk_size_bytes as usize) + .map(|b| ReadResponse { + data: bytes.slice_ref(b), + }) + .collect(), + ), + None => Err(Status::not_found(format!( + "Did not find digest {}", + fingerprint + ))), + } + } +} + +#[tonic::async_trait] +impl ByteStream for StubCASResponder { + type ReadStream = Pin> + Send + Sync>>; + + async fn read( + &self, + request: Request, + ) -> 
Result, Status> { + { + let mut request_count = self.read_request_count.lock(); + *request_count += 1; + } + check_auth!(self, request); + + let request = request.into_inner(); + + let stream_elements = self.read_internal(&request)?; + let stream = Box::pin(futures::stream::iter( + stream_elements.into_iter().map(Ok).collect::>(), + )); + Ok(Response::new(stream)) + } + + async fn write( + &self, + request: Request>, + ) -> Result, Status> { + check_auth!(self, request); + + let always_errors = self.always_errors; + let write_message_sizes = self.write_message_sizes.clone(); + let blobs = self.blobs.clone(); + + let mut stream = request.into_inner(); + + let mut maybe_resource_name = None; + let mut want_next_offset = 0; + let mut bytes = BytesMut::new(); + + while let Some(req_result) = stream.next().await { + let req = match req_result { + Ok(r) => r, + Err(e) => { + return Err(Status::invalid_argument(format!( + "Client sent an error: {}", + e + ))) + } + }; + + match maybe_resource_name { + None => maybe_resource_name = Some(req.resource_name.clone()), + Some(ref resource_name) => { + if *resource_name != req.resource_name { + return Err(Status::invalid_argument(format!( + "All resource names in stream must be the same. Got {} but earlier saw {}", + req.resource_name, resource_name + ))); + } + } + } + + if req.write_offset != want_next_offset { + return Err(Status::invalid_argument(format!( + "Missing chunk. 
Expected next offset {}, got next offset: {}", + want_next_offset, req.write_offset + ))); + } + + want_next_offset += req.data.len() as i64; + write_message_sizes.lock().push(req.data.len()); + bytes.extend_from_slice(&req.data); + } + + let bytes = bytes.freeze(); + + match maybe_resource_name { + None => Err(Status::invalid_argument( + "Stream saw no messages".to_owned(), + )), + Some(resource_name) => { + let parsed_resource_name = + parse_write_resource_name(&resource_name).map_err(Status::internal)?; + + if parsed_resource_name.instance_name != self.instance_name().as_str() { + return Err(Status::invalid_argument(format!( + "Bad instance name in resource name: expected={}, actual={}", + self.instance_name(), + parsed_resource_name.instance_name + ))); + } + + let fingerprint = match Fingerprint::from_hex_string(parsed_resource_name.hash) { + Ok(f) => f, + Err(err) => { + return Err(Status::invalid_argument(format!( + "Bad fingerprint in resource name: {}: {}", + parsed_resource_name.hash, err + ))); + } + }; + let size = parsed_resource_name.size; + if size != bytes.len() { + return Err(Status::invalid_argument(format!( + "Size was incorrect: resource name said size={} but got {}", + size, + bytes.len() + ))); + } + + if always_errors { + return Err(Status::invalid_argument( + "StubCAS is configured to always fail".to_owned(), + )); + } + + { + let mut blobs = blobs.lock(); + blobs.insert(fingerprint, bytes); + } + + let response = WriteResponse { + committed_size: size as i64, + }; + Ok(Response::new(response)) + } + } + } + + async fn query_write_status( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("".to_owned())) + } +} + +#[tonic::async_trait] +impl ContentAddressableStorage for StubCASResponder { + async fn find_missing_blobs( + &self, + request: Request, + ) -> Result, Status> { + check_auth!(self, request); + + if self.always_errors { + return Err(Status::internal( + "StubCAS is configured to always fail".to_owned(), + 
)); + } + + let request = request.into_inner(); + + check_instance_name!(self, request); + + let blobs = self.blobs.lock(); + let mut response = FindMissingBlobsResponse::default(); + for digest in request.blob_digests { + let hashing_digest_result: Result = digest.try_into(); + let hashing_digest = hashing_digest_result.expect("Bad digest"); + if !blobs.contains_key(&hashing_digest.hash) { + response.missing_blob_digests.push(hashing_digest.into()) + } + } + Ok(Response::new(response)) + } + + async fn batch_update_blobs( + &self, + request: Request, + ) -> Result, Status> { + check_auth!(self, request); + + if self.always_errors { + return Err(Status::invalid_argument( + "StubCAS is configured to always fail".to_owned(), + )); + } + + let request = request.into_inner(); + + check_instance_name!(self, request); + + let mut responses = Vec::new(); + let mut blobs = self.blobs.lock(); + + fn write_blob( + request: remexec::batch_update_blobs_request::Request, + blobs: &mut HashMap, + ) -> Status { + let digest = match request.digest { + Some(d) => d, + None => return Status::invalid_argument("digest not set in batch update request"), + }; + + let fingerprint = match Fingerprint::from_hex_string(&digest.hash) { + Ok(f) => f, + Err(err) => { + return Status::invalid_argument(format!("Bad fingerprint: {}: {}", &digest.hash, err)); + } + }; + + if request.data.len() != digest.size_bytes as usize { + return Status::invalid_argument(format!( + "Size was incorrect: digest size is {} but got {} from data", + digest.size_bytes, + request.data.len() + )); + } + + blobs.insert(fingerprint, request.data); + Status::ok("") + } + + for blob_request in request.requests { + let digest = blob_request.digest.clone(); + self + .write_message_sizes + .lock() + .push(blob_request.data.len()); + let status = write_blob(blob_request, &mut blobs); + responses.push(remexec::batch_update_blobs_response::Response { + digest, + status: Some(protos::gen::google::rpc::Status { + code: 
status.code() as i32, + message: status.message().to_string(), + ..protos::gen::google::rpc::Status::default() + }), + }) + } + + Ok(Response::new(BatchUpdateBlobsResponse { responses })) + } + + async fn batch_read_blobs( + &self, + request: Request, + ) -> Result, Status> { + check_auth!(self, request); + + if self.always_errors { + return Err(Status::invalid_argument( + "StubCAS is configured to always fail".to_owned(), + )); + } + + let request = request.into_inner(); + + check_instance_name!(self, request); + + let mut responses = Vec::new(); + let blobs = self.blobs.lock(); + + fn read_blob( + digest: remexec::Digest, + blobs: &HashMap, + ) -> (Option, Status) { + let fingerprint = match Fingerprint::from_hex_string(&digest.hash) { + Ok(f) => f, + Err(err) => { + return ( + None, + Status::invalid_argument(format!("Bad fingerprint: {}: {}", &digest.hash, err)), + ); + } + }; + + match blobs.get(&fingerprint) { + Some(data) => { + if data.len() == digest.size_bytes as usize { + (Some(data.clone()), Status::ok("")) + } else { + ( + None, + Status::invalid_argument(format!( + "Size was incorrect: digest size is {} but got {} from data", + digest.size_bytes, + data.len() + )), + ) + } + } + None => (None, Status::not_found("")), + } + } + + for digest in request.digests { + let (data_opt, status) = read_blob(digest.clone(), &blobs); + responses.push(remexec::batch_read_blobs_response::Response { + digest: Some(digest), + data: data_opt.unwrap_or_else(Bytes::new), + status: Some(protos::gen::google::rpc::Status { + code: status.code() as i32, + message: status.message().to_string(), + ..protos::gen::google::rpc::Status::default() + }), + }); + } + + Ok(Response::new(remexec::BatchReadBlobsResponse { responses })) + } + + type GetTreeStream = tonic::codec::Streaming; + + async fn get_tree( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("".to_owned())) + } +} + +#[tonic::async_trait] +impl Capabilities for StubCASResponder { + async fn 
get_capabilities( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + check_instance_name!(self, request); + + let response = ServerCapabilities { + cache_capabilities: Some(CacheCapabilities { + digest_function: vec![remexec::digest_function::Value::Sha256 as i32], + max_batch_total_size_bytes: 0, + ..CacheCapabilities::default() + }), + execution_capabilities: Some(ExecutionCapabilities { + digest_function: remexec::digest_function::Value::Sha256 as i32, + exec_enabled: true, + ..ExecutionCapabilities::default() + }), + high_api_version: Some(SemVer { + major: 2, + minor: 999, + ..SemVer::default() + }), + ..ServerCapabilities::default() + }; + + Ok(Response::new(response)) + } +} + +#[cfg(test)] +mod tests { + use super::{ + parse_read_resource_name, parse_write_resource_name, ParsedReadResourceName, + ParsedWriteResourceName, + }; + + #[test] + fn parse_write_resource_name_correctly() { + let result = parse_write_resource_name("main/uploads/uuid-12345/blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedWriteResourceName { + instance_name: "main", + _uuid: "uuid-12345", + hash: "abc123", + size: 12, + } + ); + + let result = parse_write_resource_name("uploads/uuid-12345/blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedWriteResourceName { + instance_name: "", + _uuid: "uuid-12345", + hash: "abc123", + size: 12, + } + ); + + let result = parse_write_resource_name("a/b/c/uploads/uuid-12345/blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedWriteResourceName { + instance_name: "a/b/c", + _uuid: "uuid-12345", + hash: "abc123", + size: 12, + } + ); + + // extra components after the size are accepted + let result = + parse_write_resource_name("a/b/c/uploads/uuid-12345/blobs/abc123/12/extra/stuff").unwrap(); + assert_eq!( + result, + ParsedWriteResourceName { + instance_name: "a/b/c", + _uuid: "uuid-12345", + hash: "abc123", + size: 12, + } + ); + } + + #[test] + fn 
parse_write_resource_name_errors_as_expected() { + // + let err = parse_write_resource_name("").unwrap_err(); + assert_eq!(err, "Missing resource name"); + + let err = parse_write_resource_name("main/uuid-12345/blobs/abc123/12").unwrap_err(); + assert_eq!(err, "Malformed resource name: missing `uploads` component"); + + let err = parse_write_resource_name("main/uploads/uuid-12345/abc123/12").unwrap_err(); + assert_eq!( + err, + "Malformed resource name: not enough path components after `uploads`" + ); + + let err = parse_write_resource_name("main/uploads/uuid-12345/abc123/12/foo").unwrap_err(); + assert_eq!(err, "Malformed resource name: expected `blobs` component"); + + // negative size should be rejected + let err = parse_write_resource_name("main/uploads/uuid-12345/blobs/abc123/-12").unwrap_err(); + assert_eq!(err, "Malformed resource name: cannot parse size"); + } + + #[test] + fn parse_read_resource_name_correctly() { + let result = parse_read_resource_name("main/blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedReadResourceName { + instance_name: "main", + hash: "abc123", + size: 12, + } + ); + + let result = parse_read_resource_name("blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedReadResourceName { + instance_name: "", + hash: "abc123", + size: 12, + } + ); + + let result = parse_read_resource_name("a/b/c/blobs/abc123/12").unwrap(); + assert_eq!( + result, + ParsedReadResourceName { + instance_name: "a/b/c", + hash: "abc123", + size: 12, + } + ); + } + + #[test] + fn parse_read_resource_name_errors_as_expected() { + let err = parse_read_resource_name("").unwrap_err(); + assert_eq!(err, "Missing resource name"); + + let err = parse_read_resource_name("main/abc123/12").unwrap_err(); + assert_eq!(err, "Malformed resource name: missing `blobs` component"); + + let err = parse_read_resource_name("main/blobs/12").unwrap_err(); + assert_eq!( + err, + "Malformed resource name: not enough path components after `blobs`" + ); + + // negative size 
should be rejected + let err = parse_read_resource_name("main/blobs/abc123/-12").unwrap_err(); + assert_eq!(err, "Malformed resource name: cannot parse size"); + } +} diff --git a/src/rust/engine/testutil/mock/src/lib.rs b/src/rust/engine/testutil/mock/src/lib.rs index d83d336ced2..d57eb004e90 100644 --- a/src/rust/engine/testutil/mock/src/lib.rs +++ b/src/rust/engine/testutil/mock/src/lib.rs @@ -25,10 +25,10 @@ // Arc can be more clear than needing to grok Orderings: #![allow(clippy::mutex_atomic)] -mod action_cache; +mod action_cache_service; mod cas; +mod cas_service; pub mod execution_server; -pub use crate::action_cache::StubActionCache; pub use crate::cas::{StubCAS, StubCASBuilder}; pub use crate::execution_server::MockExecution; diff --git a/src/rust/engine/workunit_store/src/metrics.rs b/src/rust/engine/workunit_store/src/metrics.rs index cb32e058338..b29a7e43e20 100644 --- a/src/rust/engine/workunit_store/src/metrics.rs +++ b/src/rust/engine/workunit_store/src/metrics.rs @@ -53,6 +53,8 @@ pub enum Metric { RemoteStoreBlobBytesDownloaded, /// Total number of bytes of blobs uploaded to a remote CAS. RemoteStoreBlobBytesUploaded, + /// Number of times that we backtracked due to missing digests. + BacktrackAttempts, } impl Metric { diff --git a/tests/python/pants_test/integration/remote_cache_integration_test.py b/tests/python/pants_test/integration/remote_cache_integration_test.py index 6c1ed96f12e..a8fbcfcaf7e 100644 --- a/tests/python/pants_test/integration/remote_cache_integration_test.py +++ b/tests/python/pants_test/integration/remote_cache_integration_test.py @@ -1,41 +1,55 @@ # Copyright 2021 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). 
+from __future__ import annotations + +import time + +from pants.engine.fs import Digest, DigestContents, DigestEntries, FileDigest, FileEntry from pants.engine.internals.native_engine import PyExecutor, PyStubCAS +from pants.engine.process import Process, ProcessResult +from pants.engine.rules import Get, rule from pants.option.global_options import RemoteCacheWarningsBehavior -from pants.option.scope import GLOBAL_SCOPE_CONFIG_SECTION from pants.testutil.pants_integration_test import run_pants +from pants.testutil.rule_runner import QueryRule, RuleRunner, logging +from pants.util.logging import LogLevel + +def remote_cache_args( + store_address: str, + warnings_behavior: RemoteCacheWarningsBehavior = RemoteCacheWarningsBehavior.backoff, +) -> list[str]: + # NB: Our options code expects `grpc://`, which it will then convert back to + # `http://` before sending over FFI. + store_address = store_address.replace("http://", "grpc://") + return [ + "--remote-cache-read", + "--remote-cache-write", + f"--remote-cache-warnings={warnings_behavior.value}", + f"--remote-store-address={store_address}", + ] -def test_warns_on_remote_cache_errors(): + +def test_warns_on_remote_cache_errors() -> None: executor = PyExecutor(core_threads=2, max_threads=4) - cas = PyStubCAS.builder().always_errors().build(executor) + cas = PyStubCAS.builder().ac_always_errors().cas_always_errors().build(executor) def run(behavior: RemoteCacheWarningsBehavior) -> str: pants_run = run_pants( [ "--backend-packages=['pants.backend.python']", "--no-dynamic-ui", + *remote_cache_args(cas.address, behavior), "package", "testprojects/src/python/hello/main:main", ], use_pantsd=False, - config={ - GLOBAL_SCOPE_CONFIG_SECTION: { - "remote_cache_read": True, - "remote_cache_write": True, - "remote_cache_warnings": behavior.value, - # NB: Our options code expects `grpc://`, which it will then convert back to - # `http://` before sending over FFI. 
- "remote_store_address": cas.address.replace("http://", "grpc://"), - } - }, ) pants_run.assert_success() return pants_run.stderr def read_err(i: int) -> str: - return f"Failed to read from remote cache ({i} occurrences so far): Unimplemented" + return f"Failed to read from remote cache ({i} occurrences so far): Unavailable" def write_err(i: int) -> str: return ( @@ -72,3 +86,66 @@ def write_err(i: int) -> str: assert err in backoff_result for err in [third_read_err, third_write_err]: assert err not in backoff_result + + +class ProcessOutputEntries(DigestEntries): + pass + + +@rule +async def entries_from_process(process_result: ProcessResult) -> ProcessOutputEntries: + # DigestEntries won't actually load file content, so we need to force it with DigestContents. + _ = await Get(DigestContents, Digest, process_result.output_digest) + return ProcessOutputEntries(await Get(DigestEntries, Digest, process_result.output_digest)) + + +@logging +def test_lazy_fetch_backtracking() -> None: + executor = PyExecutor(core_threads=2, max_threads=4) + cas = PyStubCAS.builder().build(executor) + + def run() -> tuple[FileDigest, dict[str, int]]: + # Use an isolated store to ensure that the only content is in the remote/stub cache. + rule_runner = RuleRunner( + rules=[entries_from_process, QueryRule(ProcessOutputEntries, [Process])], + isolated_local_store=True, + bootstrap_args=[ + "--no-remote-cache-eager-fetch", + "--no-local-cache", + *remote_cache_args(cas.address), + ], + ) + entries = rule_runner.request( + ProcessOutputEntries, + [ + Process( + ["/bin/bash", "-c", "sleep 1; echo content > file.txt"], + description="Create file.txt", + output_files=["file.txt"], + level=LogLevel.INFO, + ) + ], + ) + assert len(entries) == 1 + entry = entries[0] + assert isinstance(entry, FileEntry) + + # Wait for any async cache writes to complete. 
+ time.sleep(1) + return entry.file_digest, rule_runner.scheduler.get_metrics() + + # Run once to populate the remote cache, and validate that there is one entry afterwards. + assert cas.action_cache_len() == 0 + file_digest1, metrics1 = run() + assert cas.action_cache_len() == 1 + assert metrics1["remote_cache_requests"] == 1 + assert metrics1["remote_cache_requests_uncached"] == 1 + + # Then, remove the content from the remote store and run again. + assert cas.remove(file_digest1) + file_digest2, metrics2 = run() + assert file_digest1 == file_digest2 + # Validate both that we hit the cache, and that we backtracked to actually run the process. + assert metrics2["remote_cache_requests"] == 1 + assert metrics2["remote_cache_requests_cached"] == 1 + assert metrics2["backtrack_attempts"] == 1