Skip to content

Commit

Permalink
working on execute action support
Browse files Browse the repository at this point in the history
  • Loading branch information
kafonek committed Nov 23, 2023
1 parent c248866 commit 94d2704
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ zeromq = "0.3.4"

[dev-dependencies]
rstest = "0.18.2"
serial_test = "2.0.0"
4 changes: 4 additions & 0 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ pub trait Handler: Debug + Send + Sync {
#[derive(Debug, PartialEq)]
pub enum ExpectedReplyType {
KernelInfo,
ExecuteReply,
None,
}

impl From<&Request> for ExpectedReplyType {
fn from(request: &Request) -> Self {
match request {
Request::KernelInfo(_) => ExpectedReplyType::KernelInfo,
Request::Execute(_) => ExpectedReplyType::ExecuteReply,
_ => ExpectedReplyType::None,

Check failure on line 30 in src/actions.rs

View workflow job for this annotation

GitHub Actions / test

unreachable pattern

Check failure on line 30 in src/actions.rs

View workflow job for this annotation

GitHub Actions / lint

unreachable pattern
}
}
Expand All @@ -34,6 +36,7 @@ impl From<&Response> for ExpectedReplyType {
fn from(response: &Response) -> Self {
match response {
Response::KernelInfo(_) => ExpectedReplyType::KernelInfo,
Response::Execute(_) => ExpectedReplyType::ExecuteReply,
_ => ExpectedReplyType::None,
}
}
Expand Down Expand Up @@ -85,6 +88,7 @@ impl Action {
let mut kernel_idle = false;
let mut expected_reply_seen = match expected_reply {
ExpectedReplyType::KernelInfo => false,
ExpectedReplyType::ExecuteReply => false,
ExpectedReplyType::None => true,
};
while let Some(response) = msg_rx.recv().await {
Expand Down
22 changes: 21 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use tokio::sync::{mpsc, Notify, RwLock};
use zeromq::{DealerSocket, Socket, SocketRecv, SocketSend, SubSocket, ZmqMessage};

use crate::actions::{Action, Handler};
use crate::jupyter::message_content::execute::ExecuteRequest;
use crate::jupyter::message_content::kernel_info::KernelInfoRequest;
use crate::jupyter::request::Request;
use crate::jupyter::response::Response;
Expand Down Expand Up @@ -151,6 +152,12 @@ impl Client {
let action = self.send_request(request.into(), handlers).await;
action

Check failure on line 153 in src/client.rs

View workflow job for this annotation

GitHub Actions / lint

returning the result of a `let` binding from a block
}

pub async fn execute_request(&self, code: String, handlers: Vec<Arc<dyn Handler>>) -> Action {
let request = ExecuteRequest::new(code);
let action = self.send_request(request.into(), handlers).await;
action

Check failure on line 159 in src/client.rs

View workflow job for this annotation

GitHub Actions / lint

returning the result of a `let` binding from a block
}
}

impl Drop for Client {
Expand All @@ -173,7 +180,20 @@ async fn process_message_worker(
let response: Response = zmq_msg.into();
let msg_id = response.msg_id();
if let Some(action) = actions.read().await.get(&msg_id) {
action.send(response).await.unwrap(); }
let sent = action.send(response).await;
// If we're seeing SendError here, it means we're still seeing ZMQ messages with
// parent header msg id matching a request / Action that is "completed" and has
// shut down its mpsc Receiver channel. That's probably happening because the
// Action is not configured to expect some Reply type and is "finishing" when
// Kernel status goes Idle but then we send along another Reply messages to a
// shutdown mpsc Receiver channel.
match sent {
Ok(_) => {},
Err(e) => {
dbg!(e);
}
}
}
},
_ = shutdown_signal.notified() => {
break;
Expand Down
13 changes: 13 additions & 0 deletions src/jupyter/message_content/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::collections::HashMap;
use crate::jupyter::header::Header;
use crate::jupyter::message::Message;
use crate::jupyter::request::Request;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down Expand Up @@ -44,3 +45,15 @@ impl From<ExecuteRequest> for Request {
Request::Execute(msg)
}
}

#[derive(Deserialize, Debug)]
pub struct ExecuteReply {
status: String,

Check failure on line 51 in src/jupyter/message_content/execute.rs

View workflow job for this annotation

GitHub Actions / test

fields `status` and `execution_count` are never read

Check failure on line 51 in src/jupyter/message_content/execute.rs

View workflow job for this annotation

GitHub Actions / lint

fields `status` and `execution_count` are never read
execution_count: u32,
}

impl From<Bytes> for ExecuteReply {
fn from(bytes: Bytes) -> Self {
serde_json::from_slice(&bytes).expect("Failed to deserialize ExecuteReply")
}
}
23 changes: 23 additions & 0 deletions src/jupyter/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ zeromq::ZmqMessage -> WireProtocol -> Response -> Message<T> with Jupyter messag
*/
use crate::jupyter::header::Header;
use crate::jupyter::message::Message;
use crate::jupyter::message_content::execute::ExecuteReply;
use crate::jupyter::message_content::kernel_info::KernelInfoReply;
use crate::jupyter::message_content::status::Status;
use crate::jupyter::metadata::Metadata;
Expand All @@ -20,6 +21,7 @@ pub struct UnmodeledContent(serde_json::Value);
pub enum Response {
Status(Message<Status>),
KernelInfo(Message<KernelInfoReply>),
Execute(Message<ExecuteReply>),
Unmodeled(Message<UnmodeledContent>),
}

Expand All @@ -29,9 +31,20 @@ impl Response {
match self {
Response::Status(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(),
Response::KernelInfo(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(),
Response::Execute(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(),
Response::Unmodeled(msg) => msg.parent_header.as_ref().unwrap().msg_id.to_owned(),
}
}

pub fn msg_type(&self) -> String {
// return msg_type from header
match self {
Response::Status(msg) => msg.header.msg_type.to_owned(),
Response::KernelInfo(msg) => msg.header.msg_type.to_owned(),
Response::Execute(msg) => msg.header.msg_type.to_owned(),
Response::Unmodeled(msg) => msg.header.msg_type.to_owned(),
}
}
}

impl From<WireProtocol> for Response {
Expand Down Expand Up @@ -60,6 +73,16 @@ impl From<WireProtocol> for Response {
};
Response::KernelInfo(msg)
}
"execute_reply" => {
let content: ExecuteReply = wp.content.into();
let msg: Message<ExecuteReply> = Message {
header,
parent_header: Some(parent_header),
metadata: Some(metadata),
content,
};
Response::Execute(msg)
}
_ => {
let content: UnmodeledContent = serde_json::from_slice(&wp.content)
.expect("Error deserializing unmodeled content");
Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async fn main() {

let handler = DebugHandler::new();
let handlers = vec![Arc::new(handler) as Arc<dyn Handler>];
let action = client.kernel_info_request(handlers).await;
// let action = client.kernel_info_request(handlers).await;
let action = client.execute_request("2 + 2".to_owned(), handlers).await;
action.await;
}
43 changes: 30 additions & 13 deletions tests/test_messages.rs → tests/test_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,14 @@ impl MessageCountHandler {
impl Handler for MessageCountHandler {
async fn handle(&self, msg: &Response) {
let mut counts = self.counts.lock().await;
match msg {
Response::KernelInfo(_) => {
let count = counts.get("kernel_info").unwrap_or(&0) + 1;
counts.insert("kernel_info".to_string(), count);
}
Response::Status(_) => {
let count = counts.get("status").unwrap_or(&0) + 1;
counts.insert("status".to_string(), count);
}

_ => {}
}
let msg_type = msg.msg_type();
let count = counts.entry(msg_type).or_insert(0);
*count += 1;
}
}

#[rstest::rstest]
#[serial_test::serial]
#[tokio::test]
async fn test_kernel_info(_ipykernel_process: Option<Child>) {
let connection_info = ConnectionInfo::from_file("/tmp/kernel_sidecar_rs_test.json")
Expand All @@ -75,7 +67,32 @@ async fn test_kernel_info(_ipykernel_process: Option<Child>) {
action.await;
let counts = handler.counts.lock().await;
let mut expected = HashMap::new();
expected.insert("kernel_info".to_string(), 1);
expected.insert("kernel_info_reply".to_string(), 1);
expected.insert("status".to_string(), 2);
assert_eq!(*counts, expected);
}

#[rstest::rstest]
#[serial_test::serial]
#[tokio::test]
async fn test_execute_request(_ipykernel_process: Option<Child>) {
let connection_info = ConnectionInfo::from_file("/tmp/kernel_sidecar_rs_test.json")
.expect("Failed to read connection info from fixture");
let client = Client::new(connection_info).await;

// send execute_request
let handler = MessageCountHandler::new();
let handlers = vec![Arc::new(handler.clone()) as Arc<dyn Handler>];
let action = client
.execute_request("print('hello')".to_string(), handlers)
.await;
action.await;
let counts = handler.counts.lock().await;
let mut expected = HashMap::new();
// status busy -> execute_input -> stream -> status idle & execute_reply
expected.insert("status".to_string(), 2);
expected.insert("execute_input".to_string(), 1);
expected.insert("stream".to_string(), 1);
expected.insert("execute_reply".to_string(), 1);
assert_eq!(*counts, expected);
}

0 comments on commit 94d2704

Please sign in to comment.