Skip to content

Commit

Permalink
Handle StopTunnel message
Browse files Browse the repository at this point in the history
  • Loading branch information
bobzilladev committed Nov 6, 2023
1 parent 40a99c4 commit 650ed9a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 23 deletions.
16 changes: 15 additions & 1 deletion ngrok/src/internals/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ pub const STOP_REQ: StreamType = StreamType::clamp(5);
pub const UPDATE_REQ: StreamType = StreamType::clamp(6);
pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7);
pub const SRV_INFO_REQ: StreamType = StreamType::clamp(8);
pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9);

pub const VERSION: &str = "2";
pub const VERSION: &[&str] = &["3", "2"]; // integers in priority order

/// An error that may have an ngrok error code.
/// All ngrok error codes are documented at https://ngrok.com/docs/errors
Expand Down Expand Up @@ -523,6 +524,19 @@ pub struct Update {
pub permit_major_version: bool,
}

/// A request from remote to stop a tunnel
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct StopTunnel {
/// The id of the tunnel to stop
#[serde(rename = "ClientID")]
pub client_id: String,
/// The message on why this tunnel was stopped
pub message: String,
/// An optional ngrok error code
pub error_code: String,
}

pub type UpdateResp = CommandResp;
rpc_req!(Update, UpdateResp, UPDATE_REQ);

Expand Down
63 changes: 48 additions & 15 deletions ngrok/src/internals/raw_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,23 @@ use super::{
StartTunnelWithLabel,
StartTunnelWithLabelResp,
Stop,
StopTunnel,
Unbind,
UnbindResp,
Update,
PROXY_REQ,
RESTART_REQ,
STOP_REQ,
STOP_TUNNEL_REQ,
UPDATE_REQ,
VERSION,
},
rpc::RpcRequest,
};
use crate::{
tunnel::AcceptError::TunnelClosed,
Session,
};

/// Errors arising from tunneling protocol RPC calls.
#[derive(Error, Debug)]
Expand Down Expand Up @@ -145,6 +151,7 @@ pub struct RpcClient {
pub struct IncomingStreams {
runtime: Handle,
handlers: CommandHandlers,
pub(crate) session: Option<Session>,
accept: Box<dyn TypedAccept + Send>,
}

Expand Down Expand Up @@ -220,6 +227,7 @@ impl RawSession {
incoming: IncomingStreams {
runtime,
handlers,
session: None,
accept: Box::new(accept),
},
};
Expand Down Expand Up @@ -298,7 +306,7 @@ impl RpcClient {
let req = Auth {
client_id: id.clone(),
extra,
version: vec![VERSION.into()],
version: VERSION.iter().map(|&x| x.into()).collect(),
};

let resp = self.rpc(req).await?;
Expand Down Expand Up @@ -380,6 +388,26 @@ impl RpcClient {

pub const NOT_IMPLEMENTED: &str = "the agent has not defined a callback for this operation";

async fn read_req<T>(stream: &mut TypedStream) -> Result<T, Either<io::Error, serde_json::Error>>
where
T: DeserializeOwned + Debug + 'static,
{
debug!("reading request from stream");
let mut buf = vec![];
let req = serde_json::from_value(loop {
let mut tmp = vec![0u8; 256];
let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?;
buf.extend_from_slice(&tmp[..bytes]);

if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(&buf) {
break obj;
}
})
.map_err(Either::Right)?;
debug!(?req, "read request from stream");
Ok(req)
}

async fn handle_req<T>(
handler: Option<Arc<dyn CommandHandler<T>>>,
mut stream: TypedStream,
Expand All @@ -388,20 +416,7 @@ where
T: DeserializeOwned + Debug + 'static,
{
let res = async {
debug!("reading request from stream");
let mut buf = vec![];
let req = serde_json::from_value(loop {
let mut tmp = vec![0u8; 256];
let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?;
buf.extend_from_slice(&tmp[..bytes]);

if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(&buf) {
break obj;
}
})
.map_err(Either::Right)?;
debug!(?req, "read request from stream");

let req = read_req(&mut stream).await?;
let resp = if let Some(handler) = handler {
debug!("running command handler");
handler.handle_command(req).await.err()
Expand Down Expand Up @@ -447,6 +462,24 @@ impl IncomingStreams {
self.runtime
.spawn(handle_req(self.handlers.on_stop.clone(), stream));
}
STOP_TUNNEL_REQ => {
// close the tunnel through the session
if let Some(session) = &self.session {
let req =
read_req::<StopTunnel>(&mut stream)
.await
.map_err(|e| match e {
Either::Left(err) => ReadHeaderError::from(err),
Either::Right(err) => ReadHeaderError::from(err),
})?;
session
.close_tunnel_with_error(
req.client_id,
TunnelClosed(req.message.clone(), req.error_code.clone()),
)
.await;
}
}
PROXY_REQ => {
let header = ProxyHeader::read_from_stream(&mut *stream).await?;

Expand Down
32 changes: 25 additions & 7 deletions ngrok/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use std::{
};

use arc_swap::ArcSwap;
use async_rustls::rustls::{self,};
use async_rustls::rustls::{
self,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::{
Expand Down Expand Up @@ -74,6 +76,7 @@ pub use crate::internals::{
CommandResp,
Restart,
Stop,
StopTunnel,
Update,
},
raw_session::{
Expand Down Expand Up @@ -657,21 +660,26 @@ impl SessionBuilder {
/// an error.
pub async fn connect(&self) -> Result<Session, ConnectError> {
let (dropref, dropped) = awaitdrop::awaitdrop();
let (inner, incoming) = self.connect_inner(None).await?;
let (inner, mut incoming) = self.connect_inner(None).await?;

let rt = inner.runtime.clone();

let inner = Arc::new(ArcSwap::new(inner.into()));

let session = Session {
_dropref: dropref,
inner: inner.clone(),
};

// store the session for use with StopTunnel
incoming.session = Some(session.clone());

rt.spawn(future::select(
accept_incoming(incoming, inner.clone()).boxed(),
accept_incoming(incoming, inner).boxed(),
dropped.wait(),
));

Ok(Session {
_dropref: dropref,
inner,
})
Ok(session)
}

pub(crate) fn get_or_create_tls_config(&self) -> rustls::ClientConfig {
Expand Down Expand Up @@ -944,6 +952,16 @@ impl Session {
Ok(tunnel)
}

/// Close a tunnel with an error from the remote.
/// Skips the call to unlisten, since the remote has already rejected it.
pub(crate) async fn close_tunnel_with_error(&self, id: impl AsRef<str>, err: AcceptError) {
let id = id.as_ref();
let inner = self.inner.load();
if let Some(tun) = inner.tunnels.write().await.remove(id) {
let _ = tun.tx.send(Err(err)).await;
};
}

/// Close a tunnel with the given ID.
pub async fn close_tunnel(&self, id: impl AsRef<str>) -> Result<(), RpcError> {
let id = id.as_ref();
Expand Down
3 changes: 3 additions & 0 deletions ngrok/src/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pub enum AcceptError {
/// An error arose during reconnect
#[error("reconnect error")]
Reconnect(#[from] Arc<ConnectError>),
/// The tunnel was closed.
#[error("tunnel closed")]
TunnelClosed(String, String),
}

#[derive(Clone)]
Expand Down

0 comments on commit 650ed9a

Please sign in to comment.