From 94d270473e1b8a81233c1d311dbd4bf218695dee Mon Sep 17 00:00:00 2001 From: kafonek Date: Thu, 23 Nov 2023 13:55:52 -0500 Subject: [PATCH] working on execute action support --- Cargo.toml | 1 + src/actions.rs | 4 ++ src/client.rs | 22 +++++++++- src/jupyter/message_content/execute.rs | 13 ++++++ src/jupyter/response.rs | 23 +++++++++++ src/main.rs | 3 +- tests/{test_messages.rs => test_commands.rs} | 43 ++++++++++++++------ 7 files changed, 94 insertions(+), 15 deletions(-) rename tests/{test_messages.rs => test_commands.rs} (62%) diff --git a/Cargo.toml b/Cargo.toml index b2f4885..645845b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ zeromq = "0.3.4" [dev-dependencies] rstest = "0.18.2" +serial_test = "2.0.0" diff --git a/src/actions.rs b/src/actions.rs index cd54b1d..6eb1f43 100644 --- a/src/actions.rs +++ b/src/actions.rs @@ -18,6 +18,7 @@ pub trait Handler: Debug + Send + Sync { #[derive(Debug, PartialEq)] pub enum ExpectedReplyType { KernelInfo, + ExecuteReply, None, } @@ -25,6 +26,7 @@ impl From<&Request> for ExpectedReplyType { fn from(request: &Request) -> Self { match request { Request::KernelInfo(_) => ExpectedReplyType::KernelInfo, + Request::Execute(_) => ExpectedReplyType::ExecuteReply, _ => ExpectedReplyType::None, } } @@ -34,6 +36,7 @@ impl From<&Response> for ExpectedReplyType { fn from(response: &Response) -> Self { match response { Response::KernelInfo(_) => ExpectedReplyType::KernelInfo, + Response::Execute(_) => ExpectedReplyType::ExecuteReply, _ => ExpectedReplyType::None, } } @@ -85,6 +88,7 @@ impl Action { let mut kernel_idle = false; let mut expected_reply_seen = match expected_reply { ExpectedReplyType::KernelInfo => false, + ExpectedReplyType::ExecuteReply => false, ExpectedReplyType::None => true, }; while let Some(response) = msg_rx.recv().await { diff --git a/src/client.rs b/src/client.rs index 7f0fbe3..10921c9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -46,6 +46,7 @@ use tokio::sync::{mpsc, Notify, RwLock}; use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket, ZmqMessage}; use crate::actions::{Action, Handler}; +use crate::jupyter::message_content::execute::ExecuteRequest; use crate::jupyter::message_content::kernel_info::KernelInfoRequest; use crate::jupyter::request::Request; use crate::jupyter::response::Response; @@ -151,6 +152,12 @@ impl Client { let action = self.send_request(request.into(), handlers).await; action } + + pub async fn execute_request(&self, code: String, handlers: Vec>) -> Action { + let request = ExecuteRequest::new(code); + let action = self.send_request(request.into(), handlers).await; + action + } } impl Drop for Client { @@ -173,7 +180,20 @@ async fn process_message_worker( let response: Response = zmq_msg.into(); let msg_id = response.msg_id(); if let Some(action) = actions.read().await.get(&msg_id) { - action.send(response).await.unwrap(); } + let sent = action.send(response).await; + // If we're seeing SendError here, it means we're still seeing ZMQ messages with + // parent header msg id matching a request / Action that is "completed" and has + // shut down its mpsc Receiver channel. That's probably happening because the + // Action is not configured to expect some Reply type and is "finishing" when + // Kernel status goes Idle but then we send along another Reply messages to a + // shutdown mpsc Receiver channel. + match sent { + Ok(_) => {}, + Err(e) => { + dbg!(e); + } + } + } }, _ = shutdown_signal.notified() => { break; diff --git a/src/jupyter/message_content/execute.rs b/src/jupyter/message_content/execute.rs index 7f96986..1ac722f 100644 --- a/src/jupyter/message_content/execute.rs +++ b/src/jupyter/message_content/execute.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; use crate::jupyter::header::Header; use crate::jupyter::message::Message; use crate::jupyter::request::Request; +use bytes::Bytes; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -44,3 +45,15 @@ impl From for Request { Request::Execute(msg) } } + +#[derive(Deserialize, Debug)] +pub struct ExecuteReply { + status: String, + execution_count: u32, +} + +impl From for ExecuteReply { + fn from(bytes: Bytes) -> Self { + serde_json::from_slice(&bytes).expect("Failed to deserialize ExecuteReply") + } +} diff --git a/src/jupyter/response.rs b/src/jupyter/response.rs index 755c2f1..9a3c769 100644 --- a/src/jupyter/response.rs +++ b/src/jupyter/response.rs @@ -5,6 +5,7 @@ zeromq::ZmqMessage -> WireProtocol -> Response -> Message with Jupyter messag */ use crate::jupyter::header::Header; use crate::jupyter::message::Message; +use crate::jupyter::message_content::execute::ExecuteReply; use crate::jupyter::message_content::kernel_info::KernelInfoReply; use crate::jupyter::message_content::status::Status; use crate::jupyter::metadata::Metadata; @@ -20,6 +21,7 @@ pub struct UnmodeledContent(serde_json::Value); pub enum Response { Status(Message), KernelInfo(Message), + Execute(Message), Unmodeled(Message), } @@ -29,9 +31,20 @@ impl Response { match self { Response::Status(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(), Response::KernelInfo(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(), + Response::Execute(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(), Response::Unmodeled(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(), } } + + pub fn msg_type(&self) -> String { + // return msg_type from header + match self { + Response::Status(msg) => msg.header.msg_type.to_owned(), + Response::KernelInfo(msg) => msg.header.msg_type.to_owned(), + Response::Execute(msg) => msg.header.msg_type.to_owned(), + Response::Unmodeled(msg) => msg.header.msg_type.to_owned(), + } + } } impl From for Response { @@ -60,6 +73,16 @@ impl From for Response { }; Response::KernelInfo(msg) } + "execute_reply" => { + let content: ExecuteReply = wp.content.into(); + let msg: Message = Message { + header, + parent_header: Some(parent_header), + metadata: Some(metadata), + content, + }; + Response::Execute(msg) + } _ => { let content: UnmodeledContent = serde_json::from_slice(&wp.content) .expect("Error deserializing unmodeled content"); diff --git a/src/main.rs b/src/main.rs index 633f32f..d4a6787 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,6 +32,7 @@ async fn main() { let handler = DebugHandler::new(); let handlers = vec![Arc::new(handler) as Arc]; - let action = client.kernel_info_request(handlers).await; + // let action = client.kernel_info_request(handlers).await; + let action = client.execute_request("2 + 2".to_owned(), handlers).await; action.await; } diff --git a/tests/test_messages.rs b/tests/test_commands.rs similarity index 62% rename from tests/test_messages.rs rename to tests/test_commands.rs index d136dc8..807fc61 100644 --- a/tests/test_messages.rs +++ b/tests/test_commands.rs @@ -46,22 +46,14 @@ impl MessageCountHandler { impl Handler for MessageCountHandler { async fn handle(&self, msg: &Response) { let mut counts = self.counts.lock().await; - match msg { - Response::KernelInfo(_) => { - let count = counts.get("kernel_info").unwrap_or(&0) + 1; - counts.insert("kernel_info".to_string(), count); - } - Response::Status(_) => { - let count = counts.get("status").unwrap_or(&0) + 1; - counts.insert("status".to_string(), count); - } - - _ => {} - } + let msg_type = msg.msg_type(); + let count = counts.entry(msg_type).or_insert(0); + *count += 1; } } #[rstest::rstest] +#[serial_test::serial] #[tokio::test] async fn test_kernel_info(_ipykernel_process: Option) { let connection_info = ConnectionInfo::from_file("/tmp/kernel_sidecar_rs_test.json") @@ -75,7 +67,32 @@ async fn test_kernel_info(_ipykernel_process: Option) { action.await; let counts = handler.counts.lock().await; let mut expected = HashMap::new(); - expected.insert("kernel_info".to_string(), 1); + expected.insert("kernel_info_reply".to_string(), 1); + expected.insert("status".to_string(), 2); + assert_eq!(*counts, expected); +} + +#[rstest::rstest] +#[serial_test::serial] +#[tokio::test] +async fn test_execute_request(_ipykernel_process: Option) { + let connection_info = ConnectionInfo::from_file("/tmp/kernel_sidecar_rs_test.json") + .expect("Failed to read connection info from fixture"); + let client = Client::new(connection_info).await; + + // send execute_request + let handler = MessageCountHandler::new(); + let handlers = vec![Arc::new(handler.clone()) as Arc]; + let action = client + .execute_request("print('hello')".to_string(), handlers) + .await; + action.await; + let counts = handler.counts.lock().await; + let mut expected = HashMap::new(); + // status busy -> execute_input -> stream -> status idle & execute_reply expected.insert("status".to_string(), 2); + expected.insert("execute_input".to_string(), 1); + expected.insert("stream".to_string(), 1); + expected.insert("execute_reply".to_string(), 1); assert_eq!(*counts, expected); }