diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 1fed1e0a25c8..3f953ab6adea 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -1193,112 +1193,125 @@ impl JrMessageHandler for GooseAcpHandler { "goose-acp" } - async fn handle_message( + fn handle_message( &mut self, message: MessageCx, cx: JrConnectionCx, - ) -> Result, sacp::Error> { + ) -> impl std::future::Future, sacp::Error>> + Send { use sacp::util::MatchMessageFrom; use sacp::JrRequestCx; - MatchMessageFrom::new(message, &cx) - .if_request( - |req: InitializeRequest, req_cx: JrRequestCx| async { - req_cx.respond(self.agent.on_initialize(req).await?) - }, - ) - .await - .if_request( - |_req: AuthenticateRequest, req_cx: JrRequestCx| async { - req_cx.respond(AuthenticateResponse::new()) - }, - ) - .await - .if_request( - |req: NewSessionRequest, req_cx: JrRequestCx| async { - req_cx.respond(self.agent.on_new_session(req).await?) - }, - ) - .await - .if_request( - |req: LoadSessionRequest, req_cx: JrRequestCx| async { - req_cx.respond(self.agent.on_load_session(req, &cx).await?) - }, - ) - .await - .if_request( - |req: PromptRequest, req_cx: JrRequestCx| async { - let agent = self.agent.clone(); - let cx_clone = cx.clone(); - cx.spawn(async move { - match agent.on_prompt(req, &cx_clone).await { - Ok(response) => { - req_cx.respond(response)?; - } - Err(e) => { - req_cx.respond_with_error(e)?; + let agent = self.agent.clone(); + + // The MatchMessageFrom chain produces an ~85KB async state machine. + // Box::pin moves it to the heap so it doesn't overflow the tokio worker stack. + Box::pin(async move { + MatchMessageFrom::new(message, &cx) + .if_request( + |req: InitializeRequest, req_cx: JrRequestCx| async { + req_cx.respond(agent.on_initialize(req).await?) + }, + ) + .await + .if_request( + |_req: AuthenticateRequest, req_cx: JrRequestCx| async { + req_cx.respond(AuthenticateResponse::new()) + }, + ) + .await + .if_request( + |req: NewSessionRequest, req_cx: JrRequestCx| async { + req_cx.respond(agent.on_new_session(req).await?) + }, + ) + .await + .if_request( + |req: LoadSessionRequest, req_cx: JrRequestCx| async { + req_cx.respond(agent.on_load_session(req, &cx).await?) + }, + ) + .await + .if_request( + |req: PromptRequest, req_cx: JrRequestCx| async { + let agent = agent.clone(); + let cx_clone = cx.clone(); + cx.spawn(async move { + match agent.on_prompt(req, &cx_clone).await { + Ok(response) => { + req_cx.respond(response)?; + } + Err(e) => { + req_cx.respond_with_error(e)?; + } } - } - Ok(()) - })?; - Ok(()) - }, - ) - .await - .if_notification(|notif: CancelNotification| async { - self.agent.on_cancel(notif).await - }) - .await - // Handle methods not yet in the sacp typed API. - // - session/set_model: typed support pending in sacp - // - _: custom requests that will eventually route to goose-server - .otherwise({ - let agent = self.agent.clone(); - |message: MessageCx| async move { - match message { - MessageCx::Request(req, request_cx) - if req.method == "session/set_model" => - { - let params: SetSessionModelRequest = serde_json::from_value(req.params) - .map_err(|e| sacp::Error::invalid_params().data(e.to_string()))?; - let resp = agent - .on_set_model(¶ms.session_id.0, ¶ms.model_id.0) - .await?; - let json = serde_json::to_value(resp) - .map_err(|e| sacp::Error::internal_error().data(e.to_string()))?; - request_cx.respond(json)?; Ok(()) - } - MessageCx::Request(req, request_cx) if req.method.starts_with('_') => { - match agent.handle_custom_request(&req.method, req.params).await { - Ok(json) => request_cx.respond(json)?, - Err(e) => request_cx.respond_with_error(e)?, + })?; + Ok(()) + }, + ) + .await + .if_notification(|notif: CancelNotification| async { agent.on_cancel(notif).await }) + .await + // Handle methods not yet in the sacp typed API. + // - session/set_model: typed support pending in sacp + // - _: custom requests that will eventually route to goose-server + .otherwise({ + let agent = agent.clone(); + |message: MessageCx| async move { + match message { + MessageCx::Request(req, request_cx) + if req.method == "session/set_model" => + { + let params: SetSessionModelRequest = + serde_json::from_value(req.params).map_err(|e| { + sacp::Error::invalid_params().data(e.to_string()) + })?; + let resp = agent + .on_set_model(¶ms.session_id.0, ¶ms.model_id.0) + .await?; + let json = serde_json::to_value(resp).map_err(|e| { + sacp::Error::internal_error().data(e.to_string()) + })?; + request_cx.respond(json)?; + Ok(()) } - Ok(()) + MessageCx::Request(req, request_cx) if req.method.starts_with('_') => { + match agent.handle_custom_request(&req.method, req.params).await { + Ok(json) => request_cx.respond(json)?, + Err(e) => request_cx.respond_with_error(e)?, + } + Ok(()) + } + _ => Err(sacp::Error::method_not_found()), } - _ => Err(sacp::Error::method_not_found()), } - } - }) - .await - .map(|()| Handled::Yes) + }) + .await + .map(|()| Handled::Yes) + }) } } -pub async fn serve(agent: Arc, read: R, write: W) -> Result<()> +pub fn serve( + agent: Arc, + read: R, + write: W, +) -> std::pin::Pin> + Send>> where R: futures::AsyncRead + Unpin + Send + 'static, W: futures::AsyncWrite + Unpin + Send + 'static, { - let handler = GooseAcpHandler { agent }; + Box::pin(async move { + let handler = GooseAcpHandler { agent }; - AgentToClient::builder() - .name("goose-acp") - .with_handler(handler) - .serve(ByteStreams::new(write, read)) - .await?; + AgentToClient::builder() + .name("goose-acp") + .with_handler(handler) + .serve(ByteStreams::new(write, read)) + .await?; - Ok(()) + Ok(()) + }) } pub async fn run(builtins: Vec) -> Result<()> { diff --git a/crates/goose-acp/src/transport/http.rs b/crates/goose-acp/src/transport/http.rs index 0c1e7f28cca0..84bf195a7e16 100644 --- a/crates/goose-acp/src/transport/http.rs +++ b/crates/goose-acp/src/transport/http.rs @@ -41,13 +41,11 @@ impl HttpState { let acp_session_id = uuid::Uuid::new_v4().to_string(); + let read_stream = ReceiverToAsyncRead::new(to_agent_rx); + let write_stream = SenderToAsyncWrite::new(from_agent_tx); + let fut = crate::server::serve(agent, read_stream.compat(), write_stream.compat_write()); let handle = tokio::spawn(async move { - let read_stream = ReceiverToAsyncRead::new(to_agent_rx); - let write_stream = SenderToAsyncWrite::new(from_agent_tx); - - if let Err(e) = - crate::server::serve(agent, read_stream.compat(), write_stream.compat_write()).await - { + if let Err(e) = fut.await { error!("ACP session error: {}", e); } }); diff --git a/crates/goose-acp/src/transport/websocket.rs b/crates/goose-acp/src/transport/websocket.rs index 559375507e4a..6643e1703612 100644 --- a/crates/goose-acp/src/transport/websocket.rs +++ b/crates/goose-acp/src/transport/websocket.rs @@ -36,13 +36,11 @@ impl WsState { let acp_session_id = uuid::Uuid::new_v4().to_string(); + let read_stream = ReceiverToAsyncRead::new(to_agent_rx); + let write_stream = SenderToAsyncWrite::new(from_agent_tx); + let fut = crate::server::serve(agent, read_stream.compat(), write_stream.compat_write()); let handle = tokio::spawn(async move { - let read_stream = ReceiverToAsyncRead::new(to_agent_rx); - let write_stream = SenderToAsyncWrite::new(from_agent_tx); - - if let Err(e) = - crate::server::serve(agent, read_stream.compat(), write_stream.compat_write()).await - { + if let Err(e) = fut.await { error!("ACP WebSocket session error: {}", e); } });