diff --git a/crates/ruff_server/src/message.rs b/crates/ruff_server/src/message.rs index 034771aea828ce..bd759feb1142ed 100644 --- a/crates/ruff_server/src/message.rs +++ b/crates/ruff_server/src/message.rs @@ -6,9 +6,9 @@ use crate::server::ClientSender; static MESSENGER: OnceLock = OnceLock::new(); -pub(crate) fn init_messenger(client_sender: &ClientSender) { +pub(crate) fn init_messenger(client_sender: ClientSender) { MESSENGER - .set(client_sender.clone()) + .set(client_sender) .expect("messenger should only be initialized once"); // unregister any previously registered panic hook diff --git a/crates/ruff_server/src/server.rs b/crates/ruff_server/src/server.rs index e9944db053c630..ca9a453042b58f 100644 --- a/crates/ruff_server/src/server.rs +++ b/crates/ruff_server/src/server.rs @@ -2,7 +2,6 @@ use std::num::NonZeroUsize; -use lsp::Connection; use lsp_server as lsp; use lsp_types as types; use types::ClientCapabilities; @@ -18,6 +17,8 @@ use types::TextDocumentSyncOptions; use types::WorkDoneProgressOptions; use types::WorkspaceFoldersServerCapabilities; +use self::connection::Connection; +use self::connection::ConnectionInitializer; use self::schedule::event_loop_thread; use self::schedule::Scheduler; use self::schedule::Task; @@ -28,34 +29,39 @@ use crate::PositionEncoding; mod api; mod client; +mod connection; mod schedule; -pub(crate) use client::ClientSender; +pub(crate) use connection::ClientSender; pub(crate) type Result = std::result::Result; pub struct Server { - conn: lsp::Connection, + connection: Connection, client_capabilities: ClientCapabilities, - threads: lsp::IoThreads, worker_threads: NonZeroUsize, session: Session, } impl Server { pub fn new(worker_threads: NonZeroUsize) -> crate::Result { - let (conn, threads) = lsp::Connection::stdio(); + let connection = ConnectionInitializer::stdio(); - crate::message::init_messenger(&conn.sender); - - let (id, params) = conn.initialize_start()?; - - let init_params: types::InitializeParams = serde_json::from_value(params)?; + let (id, init_params) = connection.initialize_start()?; let client_capabilities = init_params.capabilities; let position_encoding = Self::find_best_position_encoding(&client_capabilities); let server_capabilities = Self::server_capabilities(position_encoding); + let connection = connection.initialize_finish( + id, + &server_capabilities, + crate::SERVER_NAME, + crate::version(), + )?; + + crate::message::init_messenger(connection.make_sender()); + let AllSettings { global_settings, mut workspace_settings, @@ -86,19 +92,8 @@ impl Server { anyhow::anyhow!("Failed to get the current working directory while creating a default workspace.") })?; - let initialize_data = serde_json::json!({ - "capabilities": server_capabilities, - "serverInfo": { - "name": crate::SERVER_NAME, - "version": crate::version() - } - }); - - conn.initialize_finish(id, initialize_data)?; - Ok(Self { - conn, - threads, + connection, worker_threads, session: Session::new( &client_capabilities, @@ -111,17 +106,20 @@ impl Server { } pub fn run(self) -> crate::Result<()> { - let result = event_loop_thread(move || { + event_loop_thread(move || { Self::event_loop( - &self.conn, + &self.connection, &self.client_capabilities, self.session, self.worker_threads, - ) + )?; + self.connection.close()?; + // Note: when we start routing tracing through the LSP, + // this should be replaced with a log directly to `stderr`. + tracing::info!("Server has shut down successfully"); + Ok(()) })? - .join(); - self.threads.join()?; - result + .join() } #[allow(clippy::needless_pass_by_value)] // this is because we aren't using `next_request_id` yet. @@ -132,22 +130,21 @@ impl Server { worker_threads: NonZeroUsize, ) -> crate::Result<()> { let mut scheduler = - schedule::Scheduler::new(&mut session, worker_threads, &connection.sender); + schedule::Scheduler::new(&mut session, worker_threads, connection.make_sender()); Self::try_register_capabilities(client_capabilities, &mut scheduler); - for msg in &connection.receiver { + for msg in connection.incoming() { + if connection.handle_shutdown(&msg)? { + break; + } let task = match msg { - lsp::Message::Request(req) => { - if connection.handle_shutdown(&req)? { - return Ok(()); - } - api::request(req) - } + lsp::Message::Request(req) => api::request(req), lsp::Message::Notification(notification) => api::notification(notification), lsp::Message::Response(response) => scheduler.response(response), }; scheduler.dispatch(task); } + Ok(()) } diff --git a/crates/ruff_server/src/server/client.rs b/crates/ruff_server/src/server/client.rs index d36c50ef665f8d..bd12f86d78e5c9 100644 --- a/crates/ruff_server/src/server/client.rs +++ b/crates/ruff_server/src/server/client.rs @@ -4,9 +4,7 @@ use lsp_server::{Notification, RequestId}; use rustc_hash::FxHashMap; use serde_json::Value; -use super::schedule::Task; - -pub(crate) type ClientSender = crossbeam::channel::Sender; +use super::{schedule::Task, ClientSender}; type ResponseBuilder<'s> = Box Task<'s>>; @@ -29,12 +27,12 @@ pub(crate) struct Requester<'s> { } impl<'s> Client<'s> { - pub(super) fn new(sender: &ClientSender) -> Self { + pub(super) fn new(sender: ClientSender) -> Self { Self { notifier: Notifier(sender.clone()), responder: Responder(sender.clone()), requester: Requester { - sender: sender.clone(), + sender, next_request_id: 1, response_handlers: FxHashMap::default(), }, @@ -60,16 +58,15 @@ impl Notifier { let message = lsp_server::Message::Notification(Notification::new(method, params)); - Ok(self.0.send(message)?) + self.0.send(message) } pub(crate) fn notify_method(&self, method: String) -> crate::Result<()> { - Ok(self - .0 + self.0 .send(lsp_server::Message::Notification(Notification::new( method, Value::Null, - )))?) + ))) } } @@ -82,7 +79,7 @@ impl Responder { where R: serde::Serialize, { - Ok(self.0.send( + self.0.send( match result { Ok(res) => lsp_server::Response::new_ok(id, res), Err(crate::server::api::Error { code, error }) => { @@ -90,7 +87,7 @@ impl Responder { } } .into(), - )?) + ) } } diff --git a/crates/ruff_server/src/server/connection.rs b/crates/ruff_server/src/server/connection.rs new file mode 100644 index 00000000000000..c04567c57ae84c --- /dev/null +++ b/crates/ruff_server/src/server/connection.rs @@ -0,0 +1,144 @@ +use lsp_server as lsp; +use lsp_types::{notification::Notification, request::Request}; +use std::sync::{Arc, Weak}; + +type ConnectionSender = crossbeam::channel::Sender; +type ConnectionReceiver = crossbeam::channel::Receiver; + +/// A builder for `Connection` that handles LSP initialization. +pub(crate) struct ConnectionInitializer { + connection: lsp::Connection, + threads: lsp::IoThreads, +} + +/// Handles inbound and outbound messages with the client. +pub(crate) struct Connection { + sender: Arc, + receiver: ConnectionReceiver, + threads: lsp::IoThreads, +} + +impl ConnectionInitializer { + /// Create a new LSP server connection over stdin/stdout. + pub(super) fn stdio() -> Self { + let (connection, threads) = lsp::Connection::stdio(); + Self { + connection, + threads, + } + } + + /// Starts the initialization process with the client by listening for an initialization request. + /// Returns a request ID that should be passed into `initialize_finish` later, + /// along with the initialization parameters that were provided. + pub(super) fn initialize_start( + &self, + ) -> crate::Result<(lsp::RequestId, lsp_types::InitializeParams)> { + let (id, params) = self.connection.initialize_start()?; + Ok((id, serde_json::from_value(params)?)) + } + + /// Finishes the initialization process with the client, + /// returning an initialized `Connection`. + pub(super) fn initialize_finish( + self, + id: lsp::RequestId, + server_capabilities: &lsp_types::ServerCapabilities, + name: &str, + version: &str, + ) -> crate::Result { + self.connection.initialize_finish( + id, + serde_json::json!({ + "capabilities": server_capabilities, + "serverInfo": { + "name": name, + "version": version + } + }), + )?; + let Self { + connection: lsp::Connection { sender, receiver }, + threads, + } = self; + Ok(Connection { + sender: Arc::new(sender), + receiver, + threads, + }) + } +} + +impl Connection { + /// Make a new `ClientSender` for sending messages to the client. + pub(super) fn make_sender(&self) -> ClientSender { + ClientSender { + weak_sender: Arc::downgrade(&self.sender), + } + } + + /// An iterator over incoming messages from the client. + pub(super) fn incoming(&self) -> crossbeam::channel::Iter { + self.receiver.iter() + } + + /// Check and respond to any incoming shutdown requests; returns`true` if the server should be shutdown. + pub(super) fn handle_shutdown(&self, message: &lsp::Message) -> crate::Result { + match message { + lsp::Message::Request(lsp::Request { id, method, .. }) + if method == lsp_types::request::Shutdown::METHOD => + { + self.sender + .send(lsp::Response::new_ok(id.clone(), ()).into())?; + tracing::info!("Shutdown request received. Waiting for an exit notification..."); + match self.receiver.recv_timeout(std::time::Duration::from_secs(30))? { + lsp::Message::Notification(lsp::Notification { method, .. }) if method == lsp_types::notification::Exit::METHOD => { + tracing::info!("Exit notification received. Server shutting down..."); + Ok(true) + }, + message => anyhow::bail!("Server received unexpected message {message:?} while waiting for exit notification") + } + } + lsp::Message::Notification(lsp::Notification { method, .. }) + if method == lsp_types::notification::Exit::METHOD => + { + tracing::error!("Server received an exit notification before a shutdown request was sent. Exiting..."); + Ok(true) + } + _ => Ok(false), + } + } + + /// Join the I/O threads that underpin this connection. + /// This is guaranteed to be nearly immediate since + /// we close the only active channels to these threads prior + /// to joining them. + pub(super) fn close(self) -> crate::Result<()> { + std::mem::drop( + Arc::into_inner(self.sender) + .expect("the client sender shouldn't have more than one strong reference"), + ); + std::mem::drop(self.receiver); + self.threads.join()?; + Ok(()) + } +} + +/// A weak reference to an underlying sender channel, used for communication with the client. +/// If the `Connection` that created this `ClientSender` is dropped, any `send` calls will throw +/// an error. +#[derive(Clone, Debug)] +pub(crate) struct ClientSender { + weak_sender: Weak, +} + +// note: additional wrapper functions for senders may be implemented as needed. +impl ClientSender { + pub(crate) fn send(&self, msg: lsp::Message) -> crate::Result<()> { + let Some(sender) = self.weak_sender.upgrade() else { + anyhow::bail!("The connection with the client has been closed"); + }; + + Ok(sender.send(msg)?) + } +} diff --git a/crates/ruff_server/src/server/schedule.rs b/crates/ruff_server/src/server/schedule.rs index fe8cc5c18c4e0b..f03570686aa4a9 100644 --- a/crates/ruff_server/src/server/schedule.rs +++ b/crates/ruff_server/src/server/schedule.rs @@ -1,7 +1,5 @@ use std::num::NonZeroUsize; -use crossbeam::channel::Sender; - use crate::session::Session; mod task; @@ -14,7 +12,7 @@ use self::{ thread::ThreadPriority, }; -use super::client::Client; +use super::{client::Client, ClientSender}; /// The event loop thread is actually a secondary thread that we spawn from the /// _actual_ main thread. This secondary thread has a larger stack size @@ -45,7 +43,7 @@ impl<'s> Scheduler<'s> { pub(super) fn new( session: &'s mut Session, worker_threads: NonZeroUsize, - sender: &Sender, + sender: ClientSender, ) -> Self { const FMT_THREADS: usize = 1; Self {