diff --git a/ngrok/src/internals/proto.rs b/ngrok/src/internals/proto.rs index 4f2e6e1..fb863d0 100644 --- a/ngrok/src/internals/proto.rs +++ b/ngrok/src/internals/proto.rs @@ -37,8 +37,9 @@ pub const STOP_REQ: StreamType = StreamType::clamp(5); pub const UPDATE_REQ: StreamType = StreamType::clamp(6); pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7); pub const SRV_INFO_REQ: StreamType = StreamType::clamp(8); +pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9); -pub const VERSION: &str = "2"; +pub const VERSION: &[&str] = &["3", "2"]; // integers in priority order /// An error that may have an ngrok error code. /// All ngrok error codes are documented at https://ngrok.com/docs/errors @@ -523,6 +524,19 @@ pub struct Update { pub permit_major_version: bool, } +/// A request from remote to stop a tunnel +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "PascalCase")] +pub struct StopTunnel { + /// The id of the tunnel to stop + #[serde(rename = "ClientID")] + pub client_id: String, + /// The message on why this tunnel was stopped + pub message: String, + /// An optional ngrok error code + pub error_code: String, +} + pub type UpdateResp = CommandResp; rpc_req!(Update, UpdateResp, UPDATE_REQ); diff --git a/ngrok/src/internals/raw_session.rs b/ngrok/src/internals/raw_session.rs index 2530519..18c75f4 100644 --- a/ngrok/src/internals/raw_session.rs +++ b/ngrok/src/internals/raw_session.rs @@ -65,17 +65,23 @@ use super::{ StartTunnelWithLabel, StartTunnelWithLabelResp, Stop, + StopTunnel, Unbind, UnbindResp, Update, PROXY_REQ, RESTART_REQ, STOP_REQ, + STOP_TUNNEL_REQ, UPDATE_REQ, VERSION, }, rpc::RpcRequest, }; +use crate::{ + tunnel::AcceptError::TunnelClosed, + Session, +}; /// Errors arising from tunneling protocol RPC calls. #[derive(Error, Debug)] @@ -145,6 +151,7 @@ pub struct RpcClient { pub struct IncomingStreams { runtime: Handle, handlers: CommandHandlers, + pub(crate) session: Option, accept: Box, } @@ -220,6 +227,7 @@ impl RawSession { incoming: IncomingStreams { runtime, handlers, + session: None, accept: Box::new(accept), }, }; @@ -298,7 +306,7 @@ impl RpcClient { let req = Auth { client_id: id.clone(), extra, - version: vec![VERSION.into()], + version: VERSION.iter().map(|&x| x.into()).collect(), }; let resp = self.rpc(req).await?; @@ -380,6 +388,26 @@ impl RpcClient { pub const NOT_IMPLEMENTED: &str = "the agent has not defined a callback for this operation"; +async fn read_req(stream: &mut TypedStream) -> Result> +where + T: DeserializeOwned + Debug + 'static, +{ + debug!("reading request from stream"); + let mut buf = vec![]; + let req = serde_json::from_value(loop { + let mut tmp = vec![0u8; 256]; + let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?; + buf.extend_from_slice(&tmp[..bytes]); + + if let Ok(obj) = serde_json::from_slice::(&buf) { + break obj; + } + }) + .map_err(Either::Right)?; + debug!(?req, "read request from stream"); + Ok(req) +} + async fn handle_req( handler: Option>>, mut stream: TypedStream, @@ -388,20 +416,7 @@ where T: DeserializeOwned + Debug + 'static, { let res = async { - debug!("reading request from stream"); - let mut buf = vec![]; - let req = serde_json::from_value(loop { - let mut tmp = vec![0u8; 256]; - let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?; - buf.extend_from_slice(&tmp[..bytes]); - - if let Ok(obj) = serde_json::from_slice::(&buf) { - break obj; - } - }) - .map_err(Either::Right)?; - debug!(?req, "read request from stream"); - + let req = read_req(&mut stream).await?; let resp = if let Some(handler) = handler { debug!("running command handler"); handler.handle_command(req).await.err() @@ -447,6 +462,24 @@ impl IncomingStreams { self.runtime .spawn(handle_req(self.handlers.on_stop.clone(), stream)); } + STOP_TUNNEL_REQ => { + // close the tunnel through the session + if let Some(session) = &self.session { + let req = + read_req::(&mut stream) + .await + .map_err(|e| match e { + Either::Left(err) => ReadHeaderError::from(err), + Either::Right(err) => ReadHeaderError::from(err), + })?; + session + .close_tunnel_with_error( + req.client_id, + TunnelClosed(req.message.clone(), req.error_code.clone()), + ) + .await; + } + } PROXY_REQ => { let header = ProxyHeader::read_from_stream(&mut *stream).await?; diff --git a/ngrok/src/session.rs b/ngrok/src/session.rs index 923be3b..9995047 100644 --- a/ngrok/src/session.rs +++ b/ngrok/src/session.rs @@ -17,7 +17,9 @@ use std::{ }; use arc_swap::ArcSwap; -use async_rustls::rustls::{self,}; +use async_rustls::rustls::{ + self, +}; use async_trait::async_trait; use bytes::Bytes; use futures::{ @@ -74,6 +76,7 @@ pub use crate::internals::{ CommandResp, Restart, Stop, + StopTunnel, Update, }, raw_session::{ @@ -657,21 +660,26 @@ impl SessionBuilder { /// an error. pub async fn connect(&self) -> Result { let (dropref, dropped) = awaitdrop::awaitdrop(); - let (inner, incoming) = self.connect_inner(None).await?; + let (inner, mut incoming) = self.connect_inner(None).await?; let rt = inner.runtime.clone(); let inner = Arc::new(ArcSwap::new(inner.into())); + let session = Session { + _dropref: dropref, + inner: inner.clone(), + }; + + // store the session for use with StopTunnel + incoming.session = Some(session.clone()); + rt.spawn(future::select( - accept_incoming(incoming, inner.clone()).boxed(), + accept_incoming(incoming, inner).boxed(), dropped.wait(), )); - Ok(Session { - _dropref: dropref, - inner, - }) + Ok(session) } pub(crate) fn get_or_create_tls_config(&self) -> rustls::ClientConfig { @@ -944,6 +952,16 @@ impl Session { Ok(tunnel) } + /// Close a tunnel with an error from the remote. + /// Skips the call to unlisten, since the remote has already rejected it. + pub(crate) async fn close_tunnel_with_error(&self, id: impl AsRef, err: AcceptError) { + let id = id.as_ref(); + let inner = self.inner.load(); + if let Some(tun) = inner.tunnels.write().await.remove(id) { + let _ = tun.tx.send(Err(err)).await; + }; + } + /// Close a tunnel with the given ID. pub async fn close_tunnel(&self, id: impl AsRef) -> Result<(), RpcError> { let id = id.as_ref(); diff --git a/ngrok/src/tunnel.rs b/ngrok/src/tunnel.rs index 1d4ea78..fbe9842 100644 --- a/ngrok/src/tunnel.rs +++ b/ngrok/src/tunnel.rs @@ -43,6 +43,9 @@ pub enum AcceptError { /// An error arose during reconnect #[error("reconnect error")] Reconnect(#[from] Arc), + /// The tunnel was closed. + #[error("tunnel closed")] + TunnelClosed(String, String), } #[derive(Clone)]