Skip to content

Commit

Permalink
feat(handlers): validate closure shape and emit .unregister on regist…
Browse files Browse the repository at this point in the history
…ration error
  • Loading branch information
cablehead committed Nov 28, 2024
1 parent f4f01f3 commit 5f07a3a
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ async fn handle_process_post(
Meta::default(),
engine.clone(),
script,
);
)?;

let value = handler.eval_in_thread(&pool, &frame).await;
let json = nu::value_to_json(&value);
Expand Down
53 changes: 38 additions & 15 deletions src/handlers/handler.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use crate::error::Error;
use crate::nu;
use crate::thread_pool::ThreadPool;
use nu_engine::eval_block_with_early_return;
use nu_protocol::debugger::WithoutDebug;
use nu_protocol::engine::Stack;
use nu_protocol::engine::StateWorkingSet;
use nu_protocol::PipelineData;
use nu_protocol::{Span, Value};

use scru128::Scru128Id;

use crate::error::Error;
use crate::nu;
use crate::nu::frame_to_value;
use crate::thread_pool::ThreadPool;

#[derive(Clone, Debug, serde::Deserialize)]
#[serde(untagged)]
pub enum StartDefinition {
Expand All @@ -15,7 +20,6 @@ pub enum StartDefinition {

#[derive(Clone, Debug, serde::Deserialize, Default)]
pub struct Meta {
pub stateful: Option<bool>,
pub initial_state: Option<serde_json::Value>,
pub pulse: Option<u64>,
pub start: Option<StartDefinition>,
Expand All @@ -28,6 +32,7 @@ pub struct Handler {
pub meta: Meta,
pub engine: nu::Engine,
pub closure: nu_protocol::engine::Closure,
pub stateful: bool,
pub state: Option<Value>,
}

Expand All @@ -38,19 +43,33 @@ impl Handler {
meta: Meta,
mut engine: nu::Engine,
expression: String,
) -> Self {
let closure = engine.parse_closure(&expression).unwrap();
) -> Result<Self, Error> {
let closure = engine.parse_closure(&expression)?;
let block = engine.state.get_block(closure.block_id);

// Validate closure has 1 or 2 args and set stateful
let arg_count = block.signature.required_positional.len();
let stateful = match arg_count {
1 => false,
2 => true,
_ => {
return Err(
format!("Closure must accept 1 or 2 arguments, found {}", arg_count).into(),
)
}
};

Self {
Ok(Self {
id,
topic,
meta: meta.clone(),
engine,
closure,
stateful,
state: meta
.initial_state
.map(|state| crate::nu::util::json_to_value(&state, nu_protocol::Span::unknown())),
}
})
}

pub async fn eval_in_thread(&self, pool: &ThreadPool, frame: &crate::store::Frame) -> Value {
Expand All @@ -73,14 +92,18 @@ impl Handler {
}

fn eval(&self, frame: &crate::store::Frame) -> Result<Value, Error> {
let input = nu::frame_to_pipeline(frame);
let block = self.engine.state.get_block(self.closure.block_id);
let mut stack = Stack::new();
let block = self.engine.state.get_block(self.closure.block_id);

// First arg is always frame
let frame_var_id = block.signature.required_positional[0].var_id.unwrap();
stack.add_var(frame_var_id, frame_to_value(frame, Span::unknown()));

if self.meta.stateful.unwrap_or(false) {
let var_id = block.signature.required_positional[0].var_id.unwrap();
// Second arg is state if stateful
if self.stateful {
let state_var_id = block.signature.required_positional[1].var_id.unwrap();
stack.add_var(
var_id,
state_var_id,
self.state
.clone()
.unwrap_or(Value::nothing(Span::unknown())),
Expand All @@ -91,12 +114,12 @@ impl Handler {
&self.engine.state,
&mut stack,
block,
input,
PipelineData::empty(), // no pipeline input, using args
);

Ok(output
.map_err(|err| {
let working_set = nu_protocol::engine::StateWorkingSet::new(&self.engine.state);
let working_set = StateWorkingSet::new(&self.engine.state);
nu_protocol::format_shell_error(&working_set, &err)
})?
.into_value(Span::unknown())?)
Expand Down
119 changes: 99 additions & 20 deletions src/handlers/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async fn spawn(
}

let value = handler.eval_in_thread(&pool, &frame).await;
if handler.meta.stateful.unwrap_or(false) {
if handler.stateful {
handle_result_stateful(&store, &mut handler, &frame, value).await;
} else {
handle_result_stateless(&store, &handler, &frame, value).await;
Expand Down Expand Up @@ -189,15 +189,29 @@ pub async fn serve(
.await
.unwrap();

let handler = Handler::new(
match Handler::new(
frame.id,
topic.to_string(),
meta.clone(),
engine.clone(),
expression,
);

let _ = spawn(store.clone(), handler, pool.clone()).await?;
) {
Ok(handler) => {
let _ = spawn(store.clone(), handler, pool.clone()).await?;
}
Err(err) => {
let _ = store
.append(
Frame::with_topic(format!("{}.unregister", topic))
.meta(serde_json::json!({
"handler_id": frame.id.to_string(),
"error": err.to_string(),
}))
.build(),
)
.await;
}
}
}
}

Expand Down Expand Up @@ -230,8 +244,8 @@ mod tests {
.hash(
store
.cas_insert(
r#"{||
if $in.topic != "topic2" { return }
r#"{|frame|
if $frame.topic != "topic2" { return }
"ran action"
}"#,
)
Expand Down Expand Up @@ -297,18 +311,17 @@ mod tests {
.hash(
store
.cas_insert(
r#"{|state|
if $in.topic != "count.me" { return }
mut state = $state
$state.count += 1
{ state: $state }
}"#,
r#"{|frame, state|
if $frame.topic != "count.me" { return }
mut state = $state
$state.count += 1
{ state: $state }
}"#,
)
.await
.unwrap(),
)
.meta(serde_json::json!({
"stateful": true,
"initial_state": { "count": 0 }
}))
.build(),
Expand Down Expand Up @@ -386,8 +399,8 @@ mod tests {
.hash(
store
.cas_insert(
r#"{||
if $in.topic != "pew" { return }
r#"{|frame|
if $frame.topic != "pew" { return }
"0.1"
}"#,
)
Expand Down Expand Up @@ -425,8 +438,8 @@ mod tests {
.hash(
store
.cas_insert(
r#"{||
if $in.topic != "pew" { return }
r#"{|frame|
if $frame.topic != "pew" { return }
"0.2"
}"#,
)
Expand Down Expand Up @@ -554,8 +567,8 @@ mod tests {
.hash(
store
.cas_insert(
r#"{||
$in
r#"{|frame|
$frame
}"#,
)
.await
Expand Down Expand Up @@ -589,4 +602,70 @@ mod tests {
}
}
}

#[tokio::test]
async fn test_register_invalid_closure() {
let temp_dir = TempDir::new().unwrap();
let store = Store::new(temp_dir.into_path()).await;
let pool = ThreadPool::new(4);
let engine = nu::Engine::new(store.clone()).unwrap();

{
let store = store.clone();
let _ = tokio::spawn(async move {
serve(store, engine, pool).await.unwrap();
});
}

let options = ReadOptions::builder().follow(FollowOption::On).build();
let mut recver = store.read(options).await;

assert_eq!(
recver.recv().await.unwrap().topic,
"xs.threshold".to_string()
);

// Attempt to register a closure with no arguments
let _ = store
.append(
Frame::with_topic("invalid.register")
.hash(
store
.cas_insert(
r#"{|| 42 }"#, // Invalid closure, expects at least one argument
)
.await
.unwrap(),
)
.build(),
)
.await;

// Ensure the register frame is processed
assert_eq!(
recver.recv().await.unwrap().topic,
"invalid.register".to_string()
);

// Expect an unregister frame to be appended
let frame = recver.recv().await.unwrap();
assert_eq!(frame.topic, "invalid.unregister".to_string());

// Verify the content of the error frame
let meta = frame.meta.unwrap();
let error_message = meta["error"].as_str().unwrap();
assert!(error_message.contains("Closure must accept 1 or 2 arguments"));

// Ensure no additional frames are processed
let timeout = tokio::time::sleep(std::time::Duration::from_millis(50));
tokio::pin!(timeout);
tokio::select! {
Some(frame) = recver.recv() => {
panic!("Unexpected frame processed: {:?}", frame);
}
_ = &mut timeout => {
// Success - no additional frames were processed
}
}
}
}

0 comments on commit 5f07a3a

Please sign in to comment.