Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

context_server: Add support for SSE MCP servers #25693

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions crates/context_server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ use std::{
},
time::{Duration, Instant},
};
use url::Url;
use util::TryFutureExt;

use crate::transport::{StdioTransport, Transport};
use crate::transport::{SseTransport, StdioTransport, Transport};

const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
Expand Down Expand Up @@ -127,13 +128,24 @@ struct Error {
message: String,
}

#[derive(Debug, Clone, Deserialize)]
pub enum ModelContextServer {
Binary(ModelContextServerBinary),
Endpoint(ModelContextServerEndpoint),
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerBinary {
pub executable: PathBuf,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerEndpoint {
pub endpoint: Url,
}

impl Client {
/// Creates a new Client instance for a context server.
///
Expand All @@ -142,22 +154,43 @@ impl Client {
/// It takes a server ID, binary information, and an async app context as input.
pub fn new(
server_id: ContextServerId,
binary: ModelContextServerBinary,
binary: ModelContextServer,
cx: AsyncApp,
) -> Result<Self> {
log::info!(
"starting context server (executable={:?}, args={:?})",
binary.executable,
&binary.args
);

let server_name = binary
.executable
.file_name()
.map(|name| name.to_string_lossy().to_string())
.unwrap_or_else(String::new);

let transport = Arc::new(StdioTransport::new(binary, &cx)?);
let (server_name, transport): (String, Arc<dyn Transport>) = match binary {
ModelContextServer::Binary(binary) => {
log::info!(
"starting local context server (executable={:?}, args={:?})",
binary.executable,
&binary.args
);

let server_name = binary
.executable
.file_name()
.map(|name| name.to_string_lossy().to_string())
.unwrap_or_else(String::new);

(server_name, Arc::new(StdioTransport::new(binary, &cx)?))
}
ModelContextServer::Endpoint(endpoint) => {
log::info!(
"starting remote context server (endpoint={:?})",
endpoint.endpoint,
);

let server_name = endpoint
.endpoint
.host()
.map(|name| name.to_string())
.unwrap_or_else(String::new);

(
server_name,
Arc::new(SseTransport::new(endpoint.endpoint, &cx)?),
)
}
};

let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
Expand Down
40 changes: 29 additions & 11 deletions crates/context_server/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,26 @@ impl ContextServer {

pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};

let client = Client::new(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
match &*self.config {
ServerConfig::Stdio {
command: Some(command),
settings: _,
} => client::ModelContextServer::Binary(client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
}),
ServerConfig::Sse { endpoint } => {
client::ModelContextServer::Endpoint(client::ModelContextServerEndpoint {
endpoint: endpoint.parse()?,
})
}
_ => {
bail!("invalid context server configuration")
}
},
cx.clone(),
)?;
Expand Down Expand Up @@ -233,11 +244,18 @@ impl ContextServerManager {
for (id, factory) in
registry.read_with(&cx, |registry, _| registry.context_server_factories())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
config.command = Some(extension_command);
let config = desired_servers.entry(id.clone()).or_default();
match config {
ServerConfig::Stdio { command, .. } => {
if command.is_none() {
if let Some(extension_command) =
factory(project.clone(), &cx).await.log_err()
{
*command = Some(extension_command);
}
}
}
ServerConfig::Sse { .. } => {}
Comment on lines +247 to +258
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure what we are trying to do in this section, so I simply implemented a no-op, but I'm sure there's more to it... 😬

}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/context_server/src/transport.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod sse_transport;
mod stdio_transport;

use std::pin::Pin;
Expand All @@ -6,6 +7,7 @@ use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;

pub use sse_transport::*;
pub use stdio_transport::*;

#[async_trait]
Expand Down
143 changes: 143 additions & 0 deletions crates/context_server/src/transport/sse_transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use futures::FutureExt;
use futures::{io::BufReader, AsyncBufReadExt as _, Stream};
use gpui::http_client::HttpClient;
use gpui::{AsyncApp, BackgroundExecutor};
use smol::channel;
use smol::lock::Mutex;
use url::Url;
use util::ResultExt as _;

use crate::transport::Transport;

struct MessageUrl {
url: Arc<Mutex<Option<String>>>,
url_received: channel::Receiver<()>,
}

impl MessageUrl {
fn new() -> (Self, channel::Sender<()>) {
let (url_sender, url_received) = channel::bounded::<()>(1);
(
Self {
url: Arc::new(Mutex::new(None)),
url_received,
},
url_sender,
)
}

async fn url(&self) -> Result<String> {
if let Some(url) = self.url.lock().await.clone() {
return Ok(url);
}
self.url_received.recv().await?;
Ok(self.url.lock().await.clone().unwrap())
}
}

pub struct SseTransport {
message_url: MessageUrl,
stdin_receiver: channel::Receiver<String>,
stderr_receiver: channel::Receiver<String>,
http_client: Arc<dyn HttpClient>,
}

impl SseTransport {
pub fn new(endpoint: Url, cx: &AsyncApp) -> Result<Self> {
let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
let (_stderr_sender, stderr_receiver) = channel::unbounded::<String>();
let (message_url, url_sender) = MessageUrl::new();
let http_client = cx.update(|cx| cx.http_client().clone())?;

let message_url_clone = message_url.url.clone();
cx.spawn({
let http_client = http_client.clone();
move |cx| async move {
Self::handle_sse_stream(
cx.background_executor(),
endpoint,
message_url_clone,
stdin_sender,
url_sender,
http_client,
)
.await
.log_err()
}
})
.detach();

Ok(Self {
message_url,
stdin_receiver,
stderr_receiver,
http_client,
})
}

async fn handle_sse_stream(
executor: &BackgroundExecutor,
endpoint: Url,
message_url: Arc<Mutex<Option<String>>>,
stdin_sender: channel::Sender<String>,
url_sender: channel::Sender<()>,
http_client: Arc<dyn HttpClient>,
) -> Result<()> {
loop {
let mut response = http_client
.get(endpoint.as_str(), Default::default(), true)
.await?;
let mut reader = BufReader::new(response.body_mut());
let mut line = String::new();

loop {
futures::select! {
result = reader.read_line(&mut line).fuse() => {
match result {
Ok(0) => break,
Ok(_) => {
if line.starts_with("data: ") {
let data = line.trim_start_matches("data: ");
if data.starts_with("http") {
*message_url.lock().await = Some(data.trim().to_string());
url_sender.send(()).await?;
} else {
stdin_sender.send(data.to_string()).await?;
}
}
line.clear();
},
Err(_) => break,
}
},
_ = executor.timer(Duration::from_secs(30)).fuse() => {
break;
}
Comment on lines +119 to +121
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I'm trying to account for computer sleep/wake-up, but it doesn't always work as expected. Is there a better way to handle these interruptions?

}
}
}
}
}

#[async_trait]
impl Transport for SseTransport {
async fn send(&self, message: String) -> Result<()> {
let url = self.message_url.url().await?;
self.http_client.post_json(&url, message.into()).await?;
Ok(())
}

fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.stdin_receiver.clone())
}

fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.stderr_receiver.clone())
}
}
40 changes: 28 additions & 12 deletions crates/context_server_settings/src/context_server_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,34 @@ pub fn init(cx: &mut App) {
ContextServerSettings::register(cx);
}

#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
pub struct ServerConfig {
/// The command to run this context server.
///
/// This will override the command set by an extension.
pub command: Option<ServerCommand>,
/// The settings for this context server.
///
/// Consult the documentation for the context server to see what settings
/// are supported.
#[schemars(schema_with = "server_config_settings_json_schema")]
pub settings: Option<serde_json::Value>,
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ServerConfig {
Stdio {
/// The command to run this context server.
///
/// This will override the command set by an extension.
command: Option<ServerCommand>,
/// The settings for this context server.
///
/// Consult the documentation for the context server to see what settings
/// are supported.
#[schemars(schema_with = "server_config_settings_json_schema")]
settings: Option<serde_json::Value>,
},
Sse {
/// The remote SSE endpoint.
endpoint: String,
},
}
Comment on lines +17 to +34
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially tried to go the untagged way, but I had deserialization problems, always returning an invalid configuration. Nevertheless, I don't believe this being the ideal way, as it's not backward compatible. Eager to hear what's the best approach


impl Default for ServerConfig {
fn default() -> Self {
ServerConfig::Stdio {
command: None,
settings: None,
}
}
}

fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
Expand Down
25 changes: 16 additions & 9 deletions crates/extension_host/src/wasm_host/wit/since_v0_3_0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::{anyhow, bail, Context, Result};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
use context_server_settings::ContextServerSettings;
use context_server_settings::{ContextServerSettings, ServerConfig};
use extension::{
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
};
Expand Down Expand Up @@ -664,14 +664,21 @@ impl ExtensionImports for WasmState {
})
.cloned()
.unwrap_or_default();
Ok(serde_json::to_string(&settings::ContextServerSettings {
command: settings.command.map(|command| settings::CommandSettings {
path: Some(command.path),
arguments: Some(command.args),
env: command.env.map(|env| env.into_iter().collect()),
}),
settings: settings.settings,
})?)
match settings {
ServerConfig::Stdio { command, settings } => {
Ok(serde_json::to_string(&settings::ContextServerSettings {
command: command.map(|command| settings::CommandSettings {
path: Some(command.path),
arguments: Some(command.args),
env: command.env.map(|env| env.into_iter().collect()),
}),
settings,
})?)
}
ServerConfig::Sse { .. } => {
bail!("SSE server configuration is not supported")
}
}
Comment on lines +667 to +681
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I didn't add support in the extension, because I didn't want to mess around with the API. It seems that since_v0_2_0 is using since_v0_3_0 (under the name latest). Also this depends on the general settings

}
_ => {
bail!("Unknown settings category: {}", category);
Expand Down
Loading