diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 83517fc2387ff1..b96ab3a3cb99aa 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -16,9 +16,10 @@ use std::{ }, time::{Duration, Instant}, }; +use url::Url; use util::TryFutureExt; -use crate::transport::{StdioTransport, Transport}; +use crate::transport::{SseTransport, StdioTransport, Transport}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -127,6 +128,12 @@ struct Error { message: String, } +#[derive(Debug, Clone, Deserialize)] +pub enum ModelContextServer { + Binary(ModelContextServerBinary), + Endpoint(ModelContextServerEndpoint), +} + #[derive(Debug, Clone, Deserialize)] pub struct ModelContextServerBinary { pub executable: PathBuf, @@ -134,6 +141,11 @@ pub struct ModelContextServerBinary { pub env: Option>, } +#[derive(Debug, Clone, Deserialize)] +pub struct ModelContextServerEndpoint { + pub endpoint: Url, +} + impl Client { /// Creates a new Client instance for a context server. /// @@ -142,22 +154,43 @@ impl Client { /// It takes a server ID, binary information, and an async app context as input. pub fn new( server_id: ContextServerId, - binary: ModelContextServerBinary, + binary: ModelContextServer, cx: AsyncApp, ) -> Result { - log::info!( - "starting context server (executable={:?}, args={:?})", - binary.executable, - &binary.args - ); - - let server_name = binary - .executable - .file_name() - .map(|name| name.to_string_lossy().to_string()) - .unwrap_or_else(String::new); - - let transport = Arc::new(StdioTransport::new(binary, &cx)?); + let (server_name, transport): (String, Arc) = match binary { + ModelContextServer::Binary(binary) => { + log::info!( + "starting local context server (executable={:?}, args={:?})", + binary.executable, + &binary.args + ); + + let server_name = binary + .executable + .file_name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or_else(String::new); + + (server_name, Arc::new(StdioTransport::new(binary, &cx)?)) + } + ModelContextServer::Endpoint(endpoint) => { + log::info!( + "starting remote context server (endpoint={:?})", + endpoint.endpoint, + ); + + let server_name = endpoint + .endpoint + .host() + .map(|name| name.to_string()) + .unwrap_or_else(String::new); + + ( + server_name, + Arc::new(SseTransport::new(endpoint.endpoint, &cx)?), + ) + } + }; let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); diff --git a/crates/context_server/src/manager.rs b/crates/context_server/src/manager.rs index 1441548b04820c..52d3ff3b16e6f2 100644 --- a/crates/context_server/src/manager.rs +++ b/crates/context_server/src/manager.rs @@ -63,15 +63,26 @@ impl ContextServer { pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { log::info!("starting context server {}", self.id); - let Some(command) = &self.config.command else { - bail!("no command specified for server {}", self.id); - }; + let client = Client::new( client::ContextServerId(self.id.clone()), - client::ModelContextServerBinary { - executable: Path::new(&command.path).to_path_buf(), - args: command.args.clone(), - env: command.env.clone(), + match &*self.config { + ServerConfig::Stdio { + command: Some(command), + settings: _, + } => client::ModelContextServer::Binary(client::ModelContextServerBinary { + executable: Path::new(&command.path).to_path_buf(), + args: command.args.clone(), + env: command.env.clone(), + }), + ServerConfig::Sse { endpoint } => { + client::ModelContextServer::Endpoint(client::ModelContextServerEndpoint { + endpoint: endpoint.parse()?, + }) + } + _ => { + bail!("invalid context server configuration") + } }, cx.clone(), )?; @@ -233,11 +244,18 @@ impl ContextServerManager { for (id, factory) in registry.read_with(&cx, |registry, _| registry.context_server_factories())? { - let config = desired_servers.entry(id).or_default(); - if config.command.is_none() { - if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() { - config.command = Some(extension_command); + let config = desired_servers.entry(id.clone()).or_default(); + match config { + ServerConfig::Stdio { command, .. } => { + if command.is_none() { + if let Some(extension_command) = + factory(project.clone(), &cx).await.log_err() + { + *command = Some(extension_command); + } + } } + ServerConfig::Sse { .. } => {} } } diff --git a/crates/context_server/src/transport.rs b/crates/context_server/src/transport.rs index b4f56b0ef03ac6..b3170f4184d945 100644 --- a/crates/context_server/src/transport.rs +++ b/crates/context_server/src/transport.rs @@ -1,3 +1,4 @@ +mod sse_transport; mod stdio_transport; use std::pin::Pin; @@ -6,6 +7,7 @@ use anyhow::Result; use async_trait::async_trait; use futures::Stream; +pub use sse_transport::*; pub use stdio_transport::*; #[async_trait] diff --git a/crates/context_server/src/transport/sse_transport.rs b/crates/context_server/src/transport/sse_transport.rs new file mode 100644 index 00000000000000..d82705a2dccee6 --- /dev/null +++ b/crates/context_server/src/transport/sse_transport.rs @@ -0,0 +1,143 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use futures::FutureExt; +use futures::{io::BufReader, AsyncBufReadExt as _, Stream}; +use gpui::http_client::HttpClient; +use gpui::{AsyncApp, BackgroundExecutor}; +use smol::channel; +use smol::lock::Mutex; +use url::Url; +use util::ResultExt as _; + +use crate::transport::Transport; + +struct MessageUrl { + url: Arc>>, + url_received: channel::Receiver<()>, +} + +impl MessageUrl { + fn new() -> (Self, channel::Sender<()>) { + let (url_sender, url_received) = channel::bounded::<()>(1); + ( + Self { + url: Arc::new(Mutex::new(None)), + url_received, + }, + url_sender, + ) + } + + async fn url(&self) -> Result { + if let Some(url) = self.url.lock().await.clone() { + return Ok(url); + } + self.url_received.recv().await?; + Ok(self.url.lock().await.clone().unwrap()) + } +} + +pub struct SseTransport { + message_url: MessageUrl, + stdin_receiver: channel::Receiver, + stderr_receiver: channel::Receiver, + http_client: Arc, +} + +impl SseTransport { + pub fn new(endpoint: Url, cx: &AsyncApp) -> Result { + let (stdin_sender, stdin_receiver) = channel::unbounded::(); + let (_stderr_sender, stderr_receiver) = channel::unbounded::(); + let (message_url, url_sender) = MessageUrl::new(); + let http_client = cx.update(|cx| cx.http_client().clone())?; + + let message_url_clone = message_url.url.clone(); + cx.spawn({ + let http_client = http_client.clone(); + move |cx| async move { + Self::handle_sse_stream( + cx.background_executor(), + endpoint, + message_url_clone, + stdin_sender, + url_sender, + http_client, + ) + .await + .log_err() + } + }) + .detach(); + + Ok(Self { + message_url, + stdin_receiver, + stderr_receiver, + http_client, + }) + } + + async fn handle_sse_stream( + executor: &BackgroundExecutor, + endpoint: Url, + message_url: Arc>>, + stdin_sender: channel::Sender, + url_sender: channel::Sender<()>, + http_client: Arc, + ) -> Result<()> { + loop { + let mut response = http_client + .get(endpoint.as_str(), Default::default(), true) + .await?; + let mut reader = BufReader::new(response.body_mut()); + let mut line = String::new(); + + loop { + futures::select! { + result = reader.read_line(&mut line).fuse() => { + match result { + Ok(0) => break, + Ok(_) => { + if line.starts_with("data: ") { + let data = line.trim_start_matches("data: "); + if data.starts_with("http") { + *message_url.lock().await = Some(data.trim().to_string()); + url_sender.send(()).await?; + } else { + stdin_sender.send(data.to_string()).await?; + } + } + line.clear(); + }, + Err(_) => break, + } + }, + _ = executor.timer(Duration::from_secs(30)).fuse() => { + break; + } + } + } + } + } +} + +#[async_trait] +impl Transport for SseTransport { + async fn send(&self, message: String) -> Result<()> { + let url = self.message_url.url().await?; + self.http_client.post_json(&url, message.into()).await?; + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + Box::pin(self.stdin_receiver.clone()) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(self.stderr_receiver.clone()) + } +} diff --git a/crates/context_server_settings/src/context_server_settings.rs b/crates/context_server_settings/src/context_server_settings.rs index d91a15ecfb4b55..1376778cd395b1 100644 --- a/crates/context_server_settings/src/context_server_settings.rs +++ b/crates/context_server_settings/src/context_server_settings.rs @@ -12,18 +12,34 @@ pub fn init(cx: &mut App) { ContextServerSettings::register(cx); } -#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)] -pub struct ServerConfig { - /// The command to run this context server. - /// - /// This will override the command set by an extension. - pub command: Option, - /// The settings for this context server. - /// - /// Consult the documentation for the context server to see what settings - /// are supported. - #[schemars(schema_with = "server_config_settings_json_schema")] - pub settings: Option, +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum ServerConfig { + Stdio { + /// The command to run this context server. + /// + /// This will override the command set by an extension. + command: Option, + /// The settings for this context server. + /// + /// Consult the documentation for the context server to see what settings + /// are supported. + #[schemars(schema_with = "server_config_settings_json_schema")] + settings: Option, + }, + Sse { + /// The remote SSE endpoint. + endpoint: String, + }, +} + +impl Default for ServerConfig { + fn default() -> Self { + ServerConfig::Stdio { + command: None, + settings: None, + } + } } fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema { diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_3_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_3_0.rs index b634134a6e324b..34b2ebf332b90b 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_3_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_3_0.rs @@ -7,7 +7,7 @@ use anyhow::{anyhow, bail, Context, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use async_trait::async_trait; -use context_server_settings::ContextServerSettings; +use context_server_settings::{ContextServerSettings, ServerConfig}; use extension::{ ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate, }; @@ -664,14 +664,21 @@ impl ExtensionImports for WasmState { }) .cloned() .unwrap_or_default(); - Ok(serde_json::to_string(&settings::ContextServerSettings { - command: settings.command.map(|command| settings::CommandSettings { - path: Some(command.path), - arguments: Some(command.args), - env: command.env.map(|env| env.into_iter().collect()), - }), - settings: settings.settings, - })?) + match settings { + ServerConfig::Stdio { command, settings } => { + Ok(serde_json::to_string(&settings::ContextServerSettings { + command: command.map(|command| settings::CommandSettings { + path: Some(command.path), + arguments: Some(command.args), + env: command.env.map(|env| env.into_iter().collect()), + }), + settings, + })?) + } + ServerConfig::Sse { .. } => { + bail!("SSE server configuration is not supported") + } + } } _ => { bail!("Unknown settings category: {}", category);