From 5f07a3aaaea5ae50ecc4c22ac7ec3cbbb77593a0 Mon Sep 17 00:00:00 2001 From: Andy Gayton Date: Thu, 28 Nov 2024 17:17:32 -0500 Subject: [PATCH] feat(handlers): validate closure shape and emit .unregister on registration error --- src/api.rs | 2 +- src/handlers/handler.rs | 53 +++++++++++++----- src/handlers/serve.rs | 119 +++++++++++++++++++++++++++++++++------- 3 files changed, 138 insertions(+), 36 deletions(-) diff --git a/src/api.rs b/src/api.rs index 1698d3b..7abd5c7 100644 --- a/src/api.rs +++ b/src/api.rs @@ -276,7 +276,7 @@ async fn handle_process_post( Meta::default(), engine.clone(), script, - ); + )?; let value = handler.eval_in_thread(&pool, &frame).await; let json = nu::value_to_json(&value); diff --git a/src/handlers/handler.rs b/src/handlers/handler.rs index bd93176..5fcdeef 100644 --- a/src/handlers/handler.rs +++ b/src/handlers/handler.rs @@ -1,12 +1,17 @@ -use crate::error::Error; -use crate::nu; -use crate::thread_pool::ThreadPool; use nu_engine::eval_block_with_early_return; use nu_protocol::debugger::WithoutDebug; use nu_protocol::engine::Stack; +use nu_protocol::engine::StateWorkingSet; +use nu_protocol::PipelineData; use nu_protocol::{Span, Value}; + use scru128::Scru128Id; +use crate::error::Error; +use crate::nu; +use crate::nu::frame_to_value; +use crate::thread_pool::ThreadPool; + #[derive(Clone, Debug, serde::Deserialize)] #[serde(untagged)] pub enum StartDefinition { @@ -15,7 +20,6 @@ pub enum StartDefinition { #[derive(Clone, Debug, serde::Deserialize, Default)] pub struct Meta { - pub stateful: Option, pub initial_state: Option, pub pulse: Option, pub start: Option, @@ -28,6 +32,7 @@ pub struct Handler { pub meta: Meta, pub engine: nu::Engine, pub closure: nu_protocol::engine::Closure, + pub stateful: bool, pub state: Option, } @@ -38,19 +43,33 @@ impl Handler { meta: Meta, mut engine: nu::Engine, expression: String, - ) -> Self { - let closure = engine.parse_closure(&expression).unwrap(); + ) -> Result { + let closure = engine.parse_closure(&expression)?; + let block = engine.state.get_block(closure.block_id); + + // Validate closure has 1 or 2 args and set stateful + let arg_count = block.signature.required_positional.len(); + let stateful = match arg_count { + 1 => false, + 2 => true, + _ => { + return Err( + format!("Closure must accept 1 or 2 arguments, found {}", arg_count).into(), + ) + } + }; - Self { + Ok(Self { id, topic, meta: meta.clone(), engine, closure, + stateful, state: meta .initial_state .map(|state| crate::nu::util::json_to_value(&state, nu_protocol::Span::unknown())), - } + }) } pub async fn eval_in_thread(&self, pool: &ThreadPool, frame: &crate::store::Frame) -> Value { @@ -73,14 +92,18 @@ impl Handler { } fn eval(&self, frame: &crate::store::Frame) -> Result { - let input = nu::frame_to_pipeline(frame); - let block = self.engine.state.get_block(self.closure.block_id); let mut stack = Stack::new(); + let block = self.engine.state.get_block(self.closure.block_id); + + // First arg is always frame + let frame_var_id = block.signature.required_positional[0].var_id.unwrap(); + stack.add_var(frame_var_id, frame_to_value(frame, Span::unknown())); - if self.meta.stateful.unwrap_or(false) { - let var_id = block.signature.required_positional[0].var_id.unwrap(); + // Second arg is state if stateful + if self.stateful { + let state_var_id = block.signature.required_positional[1].var_id.unwrap(); stack.add_var( - var_id, + state_var_id, self.state .clone() .unwrap_or(Value::nothing(Span::unknown())), @@ -91,12 +114,12 @@ impl Handler { &self.engine.state, &mut stack, block, - input, + PipelineData::empty(), // no pipeline input, using args ); Ok(output .map_err(|err| { - let working_set = nu_protocol::engine::StateWorkingSet::new(&self.engine.state); + let working_set = StateWorkingSet::new(&self.engine.state); nu_protocol::format_shell_error(&working_set, &err) })? .into_value(Span::unknown())?) diff --git a/src/handlers/serve.rs b/src/handlers/serve.rs index 31a77c9..8c76893 100644 --- a/src/handlers/serve.rs +++ b/src/handlers/serve.rs @@ -132,7 +132,7 @@ async fn spawn( } let value = handler.eval_in_thread(&pool, &frame).await; - if handler.meta.stateful.unwrap_or(false) { + if handler.stateful { handle_result_stateful(&store, &mut handler, &frame, value).await; } else { handle_result_stateless(&store, &handler, &frame, value).await; @@ -189,15 +189,29 @@ pub async fn serve( .await .unwrap(); - let handler = Handler::new( + match Handler::new( frame.id, topic.to_string(), meta.clone(), engine.clone(), expression, - ); - - let _ = spawn(store.clone(), handler, pool.clone()).await?; + ) { + Ok(handler) => { + let _ = spawn(store.clone(), handler, pool.clone()).await?; + } + Err(err) => { + let _ = store + .append( + Frame::with_topic(format!("{}.unregister", topic)) + .meta(serde_json::json!({ + "handler_id": frame.id.to_string(), + "error": err.to_string(), + })) + .build(), + ) + .await; + } + } } } @@ -230,8 +244,8 @@ mod tests { .hash( store .cas_insert( - r#"{|| - if $in.topic != "topic2" { return } + r#"{|frame| + if $frame.topic != "topic2" { return } "ran action" }"#, ) @@ -297,18 +311,17 @@ mod tests { .hash( store .cas_insert( - r#"{|state| - if $in.topic != "count.me" { return } - mut state = $state - $state.count += 1 - { state: $state } - }"#, + r#"{|frame, state| + if $frame.topic != "count.me" { return } + mut state = $state + $state.count += 1 + { state: $state } + }"#, ) .await .unwrap(), ) .meta(serde_json::json!({ - "stateful": true, "initial_state": { "count": 0 } })) .build(), @@ -386,8 +399,8 @@ mod tests { .hash( store .cas_insert( - r#"{|| - if $in.topic != "pew" { return } + r#"{|frame| + if $frame.topic != "pew" { return } "0.1" }"#, ) @@ -425,8 +438,8 @@ mod tests { .hash( store .cas_insert( - r#"{|| - if $in.topic != "pew" { return } + r#"{|frame| + if $frame.topic != "pew" { return } "0.2" }"#, ) @@ -554,8 +567,8 @@ mod tests { .hash( store .cas_insert( - r#"{|| - $in + r#"{|frame| + $frame }"#, ) .await @@ -589,4 +602,70 @@ mod tests { } } } + + #[tokio::test] + async fn test_register_invalid_closure() { + let temp_dir = TempDir::new().unwrap(); + let store = Store::new(temp_dir.into_path()).await; + let pool = ThreadPool::new(4); + let engine = nu::Engine::new(store.clone()).unwrap(); + + { + let store = store.clone(); + let _ = tokio::spawn(async move { + serve(store, engine, pool).await.unwrap(); + }); + } + + let options = ReadOptions::builder().follow(FollowOption::On).build(); + let mut recver = store.read(options).await; + + assert_eq!( + recver.recv().await.unwrap().topic, + "xs.threshold".to_string() + ); + + // Attempt to register a closure with no arguments + let _ = store + .append( + Frame::with_topic("invalid.register") + .hash( + store + .cas_insert( + r#"{|| 42 }"#, // Invalid closure, expects at least one argument + ) + .await + .unwrap(), + ) + .build(), + ) + .await; + + // Ensure the register frame is processed + assert_eq!( + recver.recv().await.unwrap().topic, + "invalid.register".to_string() + ); + + // Expect an unregister frame to be appended + let frame = recver.recv().await.unwrap(); + assert_eq!(frame.topic, "invalid.unregister".to_string()); + + // Verify the content of the error frame + let meta = frame.meta.unwrap(); + let error_message = meta["error"].as_str().unwrap(); + assert!(error_message.contains("Closure must accept 1 or 2 arguments")); + + // Ensure no additional frames are processed + let timeout = tokio::time::sleep(std::time::Duration::from_millis(50)); + tokio::pin!(timeout); + tokio::select! { + Some(frame) = recver.recv() => { + panic!("Unexpected frame processed: {:?}", frame); + } + _ = &mut timeout => { + // Success - no additional frames were processed + } + } + } }