diff --git a/Cargo.lock b/Cargo.lock index c3f7c26..ad17538 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4039,6 +4039,24 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "subtle" version = "2.6.1" @@ -4159,7 +4177,7 @@ dependencies = [ [[package]] name = "tee-worker-pre-compute" -version = "0.2.0" +version = "0.3.0" dependencies = [ "aes", "alloy-signer", @@ -4175,6 +4193,8 @@ dependencies = [ "serde_json", "sha256", "sha3", + "strum", + "strum_macros", "temp-env", "tempfile", "testcontainers", diff --git a/pre-compute/Cargo.toml b/pre-compute/Cargo.toml index 8ecd4a2..c821816 100644 --- a/pre-compute/Cargo.toml +++ b/pre-compute/Cargo.toml @@ -16,6 +16,8 @@ reqwest = { version = "0.12.15", features = ["blocking", "json"] } serde = "1.0.219" sha256 = "1.6.0" sha3 = "0.10.8" +strum = "0.27.2" +strum_macros = "0.27.2" thiserror = "2.0.12" [dev-dependencies] diff --git a/pre-compute/src/api/worker_api.rs b/pre-compute/src/api/worker_api.rs index 3aeeec4..838e332 100644 --- a/pre-compute/src/api/worker_api.rs +++ b/pre-compute/src/api/worker_api.rs @@ -4,40 +4,6 @@ use crate::compute::{ }; use log::error; use reqwest::{blocking::Client, header::AUTHORIZATION}; -use serde::Serialize; - -/// Represents payload that can be sent to the worker API to report the outcome of the -/// pre‑compute stage. -/// -/// The JSON structure expected by the REST endpoint is: -/// ```json -/// { -/// "cause": "" -/// } -/// ``` -/// -/// # Arguments -/// -/// * `cause` - A reference to the ReplicateStatusCause indicating why the pre-compute operation exited -/// -/// # Example -/// -/// ```rust -/// use tee_worker_pre_compute::api::worker_api::ExitMessage; -/// use tee_worker_pre_compute::compute::errors::ReplicateStatusCause; -/// -/// let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); -/// ``` -#[derive(Serialize, Debug)] -pub struct ExitMessage<'a> { - pub cause: &'a ReplicateStatusCause, -} - -impl<'a> From<&'a ReplicateStatusCause> for ExitMessage<'a> { - fn from(cause: &'a ReplicateStatusCause) -> Self { - Self { cause } - } -} /// Thin wrapper around a [`Client`] that knows how to reach the iExec worker API. /// @@ -93,21 +59,21 @@ impl WorkerApiClient { Self::new(&base_url) } - /// Sends an exit cause for a pre-compute operation to the Worker API. + /// Sends exit causes for a pre-compute operation to the Worker API. /// - /// This method reports the exit cause of a pre-compute operation to the Worker API, + /// This method reports the exit causes of a pre-compute operation to the Worker API, /// which can be used for tracking and debugging purposes. /// /// # Arguments /// /// * `authorization` - The authorization token to use for the API request - /// * `chain_task_id` - The chain task ID for which to report the exit cause - /// * `exit_cause` - The exit cause to report + /// * `chain_task_id` - The chain task ID for which to report the exit causes + /// * `exit_causes` - The list of exit causes to report /// /// # Returns /// - /// * `Ok(())` - If the exit cause was successfully reported - /// * `Err(Error)` - If the exit cause could not be reported due to an HTTP error + /// * `Ok(())` - If the exit causes were successfully reported + /// * `Err(Error)` - If the exit causes could not be reported due to an HTTP error /// /// # Errors /// @@ -117,33 +83,33 @@ impl WorkerApiClient { /// # Example /// /// ```rust - /// use tee_worker_pre_compute::api::worker_api::{ExitMessage, WorkerApiClient}; + /// use tee_worker_pre_compute::api::worker_api::WorkerApiClient; /// use tee_worker_pre_compute::compute::errors::ReplicateStatusCause; /// /// let client = WorkerApiClient::new("http://worker:13100"); - /// let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); + /// let exit_causes = vec![ReplicateStatusCause::PreComputeInvalidTeeSignature]; /// - /// match client.send_exit_cause_for_pre_compute_stage( + /// match client.send_exit_causes_for_pre_compute_stage( /// "authorization_token", /// "0x123456789abcdef", - /// &exit_message, + /// &exit_causes, /// ) { - /// Ok(()) => println!("Exit cause reported successfully"), - /// Err(error) => eprintln!("Failed to report exit cause: {error}"), + /// Ok(()) => println!("Exit causes reported successfully"), + /// Err(error) => eprintln!("Failed to report exit causes: {error}"), /// } /// ``` - pub fn send_exit_cause_for_pre_compute_stage( + pub fn send_exit_causes_for_pre_compute_stage( &self, authorization: &str, chain_task_id: &str, - exit_cause: &ExitMessage, + exit_causes: &[ReplicateStatusCause], ) -> Result<(), ReplicateStatusCause> { - let url = format!("{}/compute/pre/{chain_task_id}/exit", self.base_url); + let url = format!("{}/compute/pre/{chain_task_id}/exit-causes", self.base_url); match self .client .post(&url) .header(AUTHORIZATION, authorization) - .json(exit_cause) + .json(exit_causes) .send() { Ok(resp) => { @@ -152,12 +118,12 @@ impl WorkerApiClient { Ok(()) } else { let body = resp.text().unwrap_or_default(); - error!("Failed to send exit cause: [status:{status}, body:{body}]"); + error!("Failed to send exit causes: [status:{status}, body:{body}]"); Err(ReplicateStatusCause::PreComputeFailedUnknownIssue) } } Err(err) => { - error!("HTTP request failed when sending exit cause to {url}: {err:?}"); + error!("HTTP request failed when sending exit causes to {url}: {err:?}"); Err(ReplicateStatusCause::PreComputeFailedUnknownIssue) } } @@ -175,36 +141,52 @@ mod tests { matchers::{body_json, header, method, path}, }; - // region ExitMessage() + // region Serialization tests #[test] - fn should_serialize_exit_message() { - let causes = [ + fn serialize_replicate_status_cause_succeeds_when_single_cause() { + let causes = vec![ ( ReplicateStatusCause::PreComputeInvalidTeeSignature, - "PRE_COMPUTE_INVALID_TEE_SIGNATURE", + r#"{"cause":"PRE_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"}"#, ), ( ReplicateStatusCause::PreComputeWorkerAddressMissing, - "PRE_COMPUTE_WORKER_ADDRESS_MISSING", + r#"{"cause":"PRE_COMPUTE_WORKER_ADDRESS_MISSING","message":"Worker address related environment variable is missing"}"#, ), ( - ReplicateStatusCause::PreComputeFailedUnknownIssue, - "PRE_COMPUTE_FAILED_UNKNOWN_ISSUE", + ReplicateStatusCause::PreComputeDatasetUrlMissing("0xDatasetAdress1".to_string()), + r#"{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAdress1"}"#, + ), + ( + ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + "0xDatasetAdress2".to_string(), + ), + r#"{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAdress2"}"#, ), ]; - for (cause, message) in causes { - let exit_message = ExitMessage::from(&cause); - let serialized = to_string(&exit_message).expect("Failed to serialize"); - let expected = format!("{{\"cause\":\"{message}\"}}"); - assert_eq!(serialized, expected); + for (cause, expected_json) in causes { + let serialized = to_string(&cause).expect("Failed to serialize"); + assert_eq!(serialized, expected_json); } } + + #[test] + fn serialize_vec_of_causes_succeeds_when_multiple_causes() { + let causes = vec![ + ReplicateStatusCause::PreComputeDatasetUrlMissing("0xDatasetAdress".to_string()), + ReplicateStatusCause::PreComputeInvalidDatasetChecksum("0xDatasetAdress".to_string()), + ]; + + let serialized = to_string(&causes).expect("Failed to serialize"); + let expected = r#"[{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAdress"},{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAdress"}]"#; + assert_eq!(serialized, expected); + } // endregion // region get_worker_api_client #[test] - fn should_get_worker_api_client_with_env_var() { + fn from_env_creates_client_with_custom_host_when_env_var_set() { with_vars( vec![(WorkerHostEnvVar.name(), Some("custom-worker-host:9999"))], || { @@ -215,7 +197,7 @@ mod tests { } #[test] - fn should_get_worker_api_client_without_env_var() { + fn from_env_creates_client_with_default_host_when_env_var_unset() { temp_env::with_vars_unset(vec![WorkerHostEnvVar.name()], || { let client = WorkerApiClient::from_env(); assert_eq!(client.base_url, format!("http://{DEFAULT_WORKER_HOST}")); @@ -223,21 +205,24 @@ mod tests { } // endregion - // region send_exit_cause_for_pre_compute_stage() + // region send_exit_causes_for_pre_compute_stage() const CHALLENGE: &str = "challenge"; const CHAIN_TASK_ID: &str = "0x123456789abcdef"; #[tokio::test] - async fn should_send_exit_cause() { + async fn send_exit_causes_succeeds_when_api_returns_success() { let mock_server = MockServer::start().await; let server_url = mock_server.uri(); - let expected_body = json!({ - "cause": ReplicateStatusCause::PreComputeInvalidTeeSignature, - }); + let expected_body = json!([ + { + "cause": "PRE_COMPUTE_INVALID_TEE_SIGNATURE", + "message": "Invalid TEE signature" + } + ]); Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .and(header("Authorization", CHALLENGE)) .and(body_json(&expected_body)) .respond_with(ResponseTemplate::new(200)) @@ -246,13 +231,12 @@ mod tests { .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PreComputeInvalidTeeSignature); + let exit_causes = vec![ReplicateStatusCause::PreComputeInvalidTeeSignature]; let worker_api_client = WorkerApiClient::new(&server_url); - worker_api_client.send_exit_cause_for_pre_compute_stage( + worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ) }) .await @@ -262,26 +246,25 @@ mod tests { } #[tokio::test] - async fn should_not_send_exit_cause() { + async fn send_exit_causes_fails_when_api_returns_error() { testing_logger::setup(); let mock_server = MockServer::start().await; let server_url = mock_server.uri(); Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(503).set_body_string("Service Unavailable")) .expect(1) .mount(&mock_server) .await; let result = tokio::task::spawn_blocking(move || { - let exit_message = - ExitMessage::from(&ReplicateStatusCause::PreComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PreComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new(&server_url); - let response = worker_api_client.send_exit_cause_for_pre_compute_stage( + let response = worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ); testing_logger::validate(|captured_logs| { let logs = captured_logs @@ -292,7 +275,7 @@ mod tests { assert_eq!(logs.len(), 1); assert_eq!( logs[0].body, - "Failed to send exit cause: [status:503 Service Unavailable, body:Service Unavailable]" + "Failed to send exit causes: [status:503 Service Unavailable, body:Service Unavailable]" ); }); response @@ -308,14 +291,14 @@ mod tests { } #[test] - fn test_send_exit_cause_http_request_failure() { + fn send_exit_causes_fails_when_http_request_invalid() { testing_logger::setup(); - let exit_message = ExitMessage::from(&ReplicateStatusCause::PreComputeFailedUnknownIssue); + let exit_causes = vec![ReplicateStatusCause::PreComputeFailedUnknownIssue]; let worker_api_client = WorkerApiClient::new("wrong_url"); - let result = worker_api_client.send_exit_cause_for_pre_compute_stage( + let result = worker_api_client.send_exit_causes_for_pre_compute_stage( CHALLENGE, CHAIN_TASK_ID, - &exit_message, + &exit_causes, ); testing_logger::validate(|captured_logs| { let logs = captured_logs @@ -326,7 +309,7 @@ mod tests { assert_eq!(logs.len(), 1); assert_eq!( logs[0].body, - "HTTP request failed when sending exit cause to wrong_url/compute/pre/0x123456789abcdef/exit: reqwest::Error { kind: Builder, source: RelativeUrlWithoutBase }" + "HTTP request failed when sending exit causes to wrong_url/compute/pre/0x123456789abcdef/exit-causes: reqwest::Error { kind: Builder, source: RelativeUrlWithoutBase }" ); }); assert!(result.is_err()); diff --git a/pre-compute/src/compute/app_runner.rs b/pre-compute/src/compute/app_runner.rs index 40a1586..093a47b 100644 --- a/pre-compute/src/compute/app_runner.rs +++ b/pre-compute/src/compute/app_runner.rs @@ -1,4 +1,4 @@ -use crate::api::worker_api::{ExitMessage, WorkerApiClient}; +use crate::api::worker_api::WorkerApiClient; use crate::compute::pre_compute_app::{PreComputeApp, PreComputeAppTrait}; use crate::compute::{ errors::ReplicateStatusCause, @@ -61,14 +61,12 @@ pub fn start_with_app( } }; - let exit_message = ExitMessage { - cause: &exit_cause.clone(), - }; + let exit_causes = vec![exit_cause.clone()]; - match WorkerApiClient::from_env().send_exit_cause_for_pre_compute_stage( + match WorkerApiClient::from_env().send_exit_causes_for_pre_compute_stage( &authorization, chain_task_id, - &exit_message, + &exit_causes, ) { Ok(_) => ExitMode::ReportedFailure, Err(_) => { @@ -193,7 +191,7 @@ mod pre_compute_start_with_app_tests { let mock_server = MockServer::start().await; Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .respond_with(ResponseTemplate::new(500)) .mount(&mock_server) .await; @@ -231,14 +229,14 @@ mod pre_compute_start_with_app_tests { async fn start_succeeds_when_send_exit_cause_api_success() { let mock_server = MockServer::start().await; - let expected_cause_enum = ReplicateStatusCause::PreComputeOutputFolderNotFound; - let expected_exit_message_payload = json!({ - "cause": expected_cause_enum // Relies on ReplicateStatusCause's Serialize impl - }); + let expected_exit_message_payload = json!([{ + "cause": "PRE_COMPUTE_OUTPUT_FOLDER_NOT_FOUND", + "message": "Output folder related environment variable is missing" + }]); // Mock the worker API to return success Mock::given(method("POST")) - .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit"))) + .and(path(format!("/compute/pre/{CHAIN_TASK_ID}/exit-causes"))) .and(body_json(expected_exit_message_payload)) .respond_with(ResponseTemplate::new(200)) .expect(1) diff --git a/pre-compute/src/compute/dataset.rs b/pre-compute/src/compute/dataset.rs index 33003d0..f65543f 100644 --- a/pre-compute/src/compute/dataset.rs +++ b/pre-compute/src/compute/dataset.rs @@ -79,7 +79,9 @@ impl Dataset { } else { download_from_url(&self.url) } - .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed)?; + .ok_or(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + self.filename.clone(), + ))?; info!("Checking encrypted dataset checksum [chainTaskId:{chain_task_id}]"); let actual_checksum = sha256_from_bytes(&encrypted_content); @@ -89,7 +91,9 @@ impl Dataset { "Invalid dataset checksum [chainTaskId:{chain_task_id}, expected:{}, actual:{actual_checksum}]", self.checksum ); - return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum); + return Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + self.filename.clone(), + )); } info!("Dataset downloaded and verified successfully."); @@ -113,12 +117,14 @@ impl Dataset { &self, encrypted_content: &[u8], ) -> Result, ReplicateStatusCause> { - let key = general_purpose::STANDARD - .decode(&self.key) - .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed)?; + let key = general_purpose::STANDARD.decode(&self.key).map_err(|_| { + ReplicateStatusCause::PreComputeDatasetDecryptionFailed(self.filename.clone()) + })?; if encrypted_content.len() < AES_IV_LENGTH || key.len() != AES_KEY_LENGTH { - return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed); + return Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + self.filename.clone(), + )); } let key_slice = &key[..AES_KEY_LENGTH]; @@ -127,7 +133,9 @@ impl Dataset { Aes256CbcDec::new(key_slice.into(), iv_slice.into()) .decrypt_padded_vec_mut::(ciphertext) - .map_err(|_| ReplicateStatusCause::PreComputeDatasetDecryptionFailed) + .map_err(|_| { + ReplicateStatusCause::PreComputeDatasetDecryptionFailed(self.filename.clone()) + }) } } @@ -144,7 +152,7 @@ mod tests { "0x02a12ef127dcfbdb294a090c8f0b69a0ca30b7940fc36cabf971f488efd374d7"; const ENCRYPTED_DATASET_KEY: &str = "ubA6H9emVPJT91/flYAmnKHC0phSV3cfuqsLxQfgow0="; const HTTP_DATASET_URL: &str = "https://raw.githubusercontent.com/iExecBlockchainComputing/tee-worker-pre-compute-rust/main/src/tests_resources/encrypted-data.bin"; - const PLAIN_DATA_FILE: &str = "plain-data.txt"; + const PLAIN_DATA_FILE: &str = "0xDatasetAddress"; const IPFS_DATASET_URL: &str = "/ipfs/QmUVhChbLFiuzNK1g2GsWyWEiad7SXPqARnWzGumgziwEp"; fn get_test_dataset() -> Dataset { @@ -171,7 +179,9 @@ mod tests { let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); assert_eq!( actual_content, - Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed) + Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + PLAIN_DATA_FILE.to_string() + )) ); } @@ -191,7 +201,9 @@ mod tests { let mut dataset = get_test_dataset(); dataset.url = "/ipfs/INVALID_IPFS_DATASET_URL".to_string(); let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); - let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed); + let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed( + PLAIN_DATA_FILE.to_string(), + )); assert_eq!(actual_content, expected_content); } @@ -200,7 +212,9 @@ mod tests { let mut dataset = get_test_dataset(); dataset.checksum = "invalid_dataset_checksum".to_string(); let actual_content = dataset.download_encrypted_dataset(CHAIN_TASK_ID); - let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum); + let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + PLAIN_DATA_FILE.to_string(), + )); assert_eq!(actual_content, expected_content); } // endregion @@ -226,7 +240,9 @@ mod tests { assert_eq!( actual_plain_data, - Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed) + Err(ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + PLAIN_DATA_FILE.to_string() + )) ); } // endregion diff --git a/pre-compute/src/compute/errors.rs b/pre-compute/src/compute/errors.rs index 51ceace..f9981de 100644 --- a/pre-compute/src/compute/errors.rs +++ b/pre-compute/src/compute/errors.rs @@ -1,24 +1,26 @@ -use serde::{Deserialize, Serialize}; +use serde::{Serializer, ser::SerializeStruct}; +use strum_macros::EnumDiscriminants; use thiserror::Error; -#[derive(Debug, PartialEq, Clone, Error, Serialize, Deserialize)] -#[serde(rename_all(serialize = "SCREAMING_SNAKE_CASE"))] +#[derive(Debug, PartialEq, Clone, Error, EnumDiscriminants)] +#[strum_discriminants(derive(serde::Serialize))] +#[strum_discriminants(serde(rename_all = "SCREAMING_SNAKE_CASE"))] #[allow(clippy::enum_variant_names)] pub enum ReplicateStatusCause { - #[error("At least one input file URL is missing")] - PreComputeAtLeastOneInputFileUrlMissing, - #[error("Dataset checksum related environment variable is missing")] - PreComputeDatasetChecksumMissing, - #[error("Failed to decrypt dataset")] - PreComputeDatasetDecryptionFailed, - #[error("Failed to download encrypted dataset file")] - PreComputeDatasetDownloadFailed, - #[error("Dataset filename related environment variable is missing")] - PreComputeDatasetFilenameMissing, - #[error("Dataset key related environment variable is missing")] - PreComputeDatasetKeyMissing, - #[error("Dataset URL related environment variable is missing")] - PreComputeDatasetUrlMissing, + #[error("input file URL {0} is missing")] + PreComputeAtLeastOneInputFileUrlMissing(usize), + #[error("Dataset checksum related environment variable is missing for dataset {0}")] + PreComputeDatasetChecksumMissing(String), + #[error("Failed to decrypt dataset {0}")] + PreComputeDatasetDecryptionFailed(String), + #[error("Failed to download encrypted dataset file for dataset {0}")] + PreComputeDatasetDownloadFailed(String), + #[error("Dataset filename related environment variable is missing for dataset {0}")] + PreComputeDatasetFilenameMissing(String), + #[error("Dataset key related environment variable is missing for dataset {0}")] + PreComputeDatasetKeyMissing(String), + #[error("Dataset URL related environment variable is missing for dataset {0}")] + PreComputeDatasetUrlMissing(String), #[error("Unexpected error occurred")] PreComputeFailedUnknownIssue, #[error("Invalid TEE signature")] @@ -29,9 +31,9 @@ pub enum ReplicateStatusCause { PreComputeInputFileDownloadFailed, #[error("Input files number related environment variable is missing")] PreComputeInputFilesNumberMissing, - #[error("Invalid dataset checksum")] - PreComputeInvalidDatasetChecksum, - #[error("Input files number related environment variable is missing")] + #[error("Invalid dataset checksum for dataset {0}")] + PreComputeInvalidDatasetChecksum(String), + #[error("Output folder related environment variable is missing")] PreComputeOutputFolderNotFound, #[error("Output path related environment variable is missing")] PreComputeOutputPathMissing, @@ -44,3 +46,92 @@ pub enum ReplicateStatusCause { #[error("Worker address related environment variable is missing")] PreComputeWorkerAddressMissing, } + +impl serde::Serialize for ReplicateStatusCause { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("ReplicateStatusCause", 2)?; + state.serialize_field("cause", &ReplicateStatusCauseDiscriminants::from(self))?; + state.serialize_field("message", &self.to_string())?; + state.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::to_string; + + const DATASET_FILENAME: &str = "0xDatasetAddress"; + + #[test] + fn serialize_produces_correct_json_when_error_has_dataset_filename() { + let cause = ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()); + let serialized = to_string(&cause).unwrap(); + assert_eq!( + serialized, + r#"{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAddress"}"# + ); + } + + #[test] + fn serialize_produces_correct_json_when_error_has_no_index() { + let cause = ReplicateStatusCause::PreComputeInvalidTeeSignature; + let serialized = to_string(&cause).unwrap(); + assert_eq!( + serialized, + r#"{"cause":"PRE_COMPUTE_INVALID_TEE_SIGNATURE","message":"Invalid TEE signature"}"# + ); + } + + #[test] + fn serialize_produces_correct_json_when_multiple_dataset_errors_with_filenames() { + let test_cases = vec![ + ( + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(1), + r#"{"cause":"PRE_COMPUTE_AT_LEAST_ONE_INPUT_FILE_URL_MISSING","message":"input file URL 1 is missing"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_DATASET_CHECKSUM_MISSING","message":"Dataset checksum related environment variable is missing for dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetDecryptionFailed( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_DATASET_DECRYPTION_FAILED","message":"Failed to decrypt dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeDatasetDownloadFailed(DATASET_FILENAME.to_string()), + r#"{"cause":"PRE_COMPUTE_DATASET_DOWNLOAD_FAILED","message":"Failed to download encrypted dataset file for dataset 0xDatasetAddress"}"#, + ), + ( + ReplicateStatusCause::PreComputeInvalidDatasetChecksum( + DATASET_FILENAME.to_string(), + ), + r#"{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xDatasetAddress"}"#, + ), + ]; + + for (cause, expected) in test_cases { + let serialized = to_string(&cause).unwrap(); + assert_eq!(serialized, expected); + } + } + + #[test] + fn serialize_produces_correct_json_when_vector_of_multiple_errors() { + let causes = vec![ + ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()), + ReplicateStatusCause::PreComputeInvalidDatasetChecksum("0xAnotherDataset".to_string()), + ]; + + let serialized = to_string(&causes).unwrap(); + let expected = r#"[{"cause":"PRE_COMPUTE_DATASET_URL_MISSING","message":"Dataset URL related environment variable is missing for dataset 0xDatasetAddress"},{"cause":"PRE_COMPUTE_INVALID_DATASET_CHECKSUM","message":"Invalid dataset checksum for dataset 0xAnotherDataset"}]"#; + assert_eq!(serialized, expected); + } +} diff --git a/pre-compute/src/compute/pre_compute_app.rs b/pre-compute/src/compute/pre_compute_app.rs index ca12b10..825f3ef 100644 --- a/pre-compute/src/compute/pre_compute_app.rs +++ b/pre-compute/src/compute/pre_compute_app.rs @@ -59,7 +59,7 @@ impl PreComputeAppTrait for PreComputeApp { // TODO: Collect all errors instead of propagating immediately, and return the list of errors self.pre_compute_args = PreComputeArgs::read_args()?; self.check_output_folder()?; - for dataset in &self.pre_compute_args.datasets { + for dataset in self.pre_compute_args.datasets.iter() { let encrypted_content = dataset.download_encrypted_dataset(&self.chain_task_id)?; let plain_content = dataset.decrypt_dataset(&encrypted_content)?; self.save_plain_dataset_file(&plain_content, &dataset.filename)?; diff --git a/pre-compute/src/compute/pre_compute_args.rs b/pre-compute/src/compute/pre_compute_args.rs index 1bd074c..e230afb 100644 --- a/pre-compute/src/compute/pre_compute_args.rs +++ b/pre-compute/src/compute/pre_compute_args.rs @@ -86,21 +86,21 @@ impl PreComputeArgs { // Read datasets let start_index = if is_dataset_required { 0 } else { 1 }; for i in start_index..=iexec_bulk_slice_size { + let filename = get_env_var_or_error( + TeeSessionEnvironmentVariable::IexecDatasetFilename(i), + ReplicateStatusCause::PreComputeDatasetFilenameMissing(format!("dataset_{i}")), + )?; let url = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetUrl(i), - ReplicateStatusCause::PreComputeDatasetUrlMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetUrlMissing(filename.clone()), )?; let checksum = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetChecksum(i), - ReplicateStatusCause::PreComputeDatasetChecksumMissing, // TODO: replace with a more specific error for bulk dataset - )?; - let filename = get_env_var_or_error( - TeeSessionEnvironmentVariable::IexecDatasetFilename(i), - ReplicateStatusCause::PreComputeDatasetFilenameMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetChecksumMissing(filename.clone()), )?; let key = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecDatasetKey(i), - ReplicateStatusCause::PreComputeDatasetKeyMissing, // TODO: replace with a more specific error for bulk dataset + ReplicateStatusCause::PreComputeDatasetKeyMissing(filename.clone()), )?; datasets.push(Dataset::new(url, checksum, filename, key)); @@ -118,7 +118,7 @@ impl PreComputeArgs { for i in 1..=input_files_nb { let url = get_env_var_or_error( TeeSessionEnvironmentVariable::IexecInputFileUrlPrefix(i), - ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing, + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(i), )?; input_files.push(url); } @@ -427,7 +427,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetUrlMissing + ReplicateStatusCause::PreComputeDatasetUrlMissing("bulk-dataset-1.txt".to_string()) ); }); } @@ -446,7 +446,9 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetChecksumMissing + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + "bulk-dataset-2.txt".to_string() + ) ); }); } @@ -465,7 +467,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetFilenameMissing + ReplicateStatusCause::PreComputeDatasetFilenameMissing("dataset_2".to_string()) ); }); } @@ -484,7 +486,7 @@ mod tests { assert!(result.is_err()); assert_eq!( result.unwrap_err(), - ReplicateStatusCause::PreComputeDatasetKeyMissing + ReplicateStatusCause::PreComputeDatasetKeyMissing("bulk-dataset-1.txt".to_string()) ); }); } @@ -508,23 +510,25 @@ mod tests { ), ( IexecDatasetUrl(0), - ReplicateStatusCause::PreComputeDatasetUrlMissing, + ReplicateStatusCause::PreComputeDatasetUrlMissing(DATASET_FILENAME.to_string()), ), ( IexecDatasetKey(0), - ReplicateStatusCause::PreComputeDatasetKeyMissing, + ReplicateStatusCause::PreComputeDatasetKeyMissing(DATASET_FILENAME.to_string()), ), ( IexecDatasetChecksum(0), - ReplicateStatusCause::PreComputeDatasetChecksumMissing, + ReplicateStatusCause::PreComputeDatasetChecksumMissing( + DATASET_FILENAME.to_string(), + ), ), ( IexecDatasetFilename(0), - ReplicateStatusCause::PreComputeDatasetFilenameMissing, + ReplicateStatusCause::PreComputeDatasetFilenameMissing("dataset_0".to_string()), ), ( IexecInputFileUrlPrefix(1), - ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing, + ReplicateStatusCause::PreComputeAtLeastOneInputFileUrlMissing(1), ), ]; for (env_var, error) in missing_env_var_causes { diff --git a/pre-compute/src/compute/utils/env_utils.rs b/pre-compute/src/compute/utils/env_utils.rs index 270f0d6..72598d5 100644 --- a/pre-compute/src/compute/utils/env_utils.rs +++ b/pre-compute/src/compute/utils/env_utils.rs @@ -71,6 +71,8 @@ mod tests { use super::*; use temp_env; + const DATASET_ADDRESS: &str = "0xDatasetAddress"; + #[test] fn name_succeeds_when_simple_environment_variable_names() { assert_eq!( @@ -202,7 +204,8 @@ mod tests { #[test] fn get_env_var_or_error_succeeds_when_indexed_variables() { let env_var = TeeSessionEnvironmentVariable::IexecDatasetChecksum(1); - let status_cause = ReplicateStatusCause::PreComputeDatasetChecksumMissing; + let status_cause = + ReplicateStatusCause::PreComputeDatasetChecksumMissing(DATASET_ADDRESS.to_string()); temp_env::with_var("IEXEC_DATASET_1_CHECKSUM", Some("abc123def456"), || { let result = get_env_var_or_error(env_var, status_cause.clone());