diff --git a/ngrok/src/config/common.rs b/ngrok/src/config/common.rs index 89cc03b..b440a44 100644 --- a/ngrok/src/config/common.rs +++ b/ngrok/src/config/common.rs @@ -206,9 +206,11 @@ pub(crate) struct CommonOpts { pub(crate) forwards_proto: Option, // Whether to disable certificate verification for this tunnel. verify_upstream_tls: Option, + // DEPRECATED: use traffic_policy instead. + pub(crate) policy: Option, // Policy that defines rules that should be applied to incoming or outgoing // connections to the edge. - pub(crate) policy: Option, + pub(crate) traffic_policy: Option, } impl CommonOpts { diff --git a/ngrok/src/config/http.rs b/ngrok/src/config/http.rs index 4f54c6a..13403e8 100644 --- a/ngrok/src/config/http.rs +++ b/ngrok/src/config/http.rs @@ -1,6 +1,7 @@ use std::{ borrow::Borrow, collections::HashMap, + convert::From, str::FromStr, }; @@ -186,7 +187,13 @@ impl TunnelConfig for HttpOptions { .websocket_tcp_conversion .then_some(WebsocketTcpConverter {}), user_agent_filter: self.user_agent_filter(), - policy: self.common_opts.policy.clone().map(From::from), + traffic_policy: if self.common_opts.traffic_policy.is_some() { + self.common_opts.traffic_policy.clone().map(From::from) + } else if self.common_opts.policy.is_some() { + self.common_opts.policy.clone().map(From::from) + } else { + None + }, ..Default::default() }; @@ -433,7 +440,7 @@ impl HttpTunnelBuilder { self } - /// Set the policy for this edge. + /// DEPRECATED: use traffic_policy instead. pub fn policy(&mut self, s: S) -> Result<&mut Self, S::Error> where S: TryInto, @@ -442,6 +449,12 @@ impl HttpTunnelBuilder { Ok(self) } + /// Set policy for this edge. + pub fn traffic_policy(&mut self, policy_str: impl Into) -> &mut Self { + self.options.common_opts.traffic_policy = Some(policy_str.into()); + self + } + pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self { self.options.common_opts.for_forwarding_to(to_url); if let Some(host) = to_url.host_str().filter(|_| self.options.rewrite_host) { diff --git a/ngrok/src/config/policies.rs b/ngrok/src/config/policies.rs index 1d28b1a..f01d7fa 100644 --- a/ngrok/src/config/policies.rs +++ b/ngrok/src/config/policies.rs @@ -163,6 +163,12 @@ impl Action { } } +impl From for proto::PolicyWrapper { + fn from(value: Policy) -> Self { + proto::PolicyWrapper::Policy(value.into()) + } +} + // transform into the wire protocol format impl From for proto::Policy { fn from(o: Policy) -> Self { diff --git a/ngrok/src/config/tcp.rs b/ngrok/src/config/tcp.rs index 33f23ac..36956b4 100644 --- a/ngrok/src/config/tcp.rs +++ b/ngrok/src/config/tcp.rs @@ -1,4 +1,7 @@ -use std::collections::HashMap; +use std::{ + collections::HashMap, + convert::From, +}; use url::Url; @@ -72,7 +75,13 @@ impl TunnelConfig for TcpOptions { tcp_endpoint.ip_restriction = self.common_opts.ip_restriction(); - tcp_endpoint.policy = self.common_opts.policy.clone().map(From::from); + tcp_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() { + self.common_opts.traffic_policy.clone().map(From::from) + } else if self.common_opts.policy.is_some() { + self.common_opts.policy.clone().map(From::from) + } else { + None + }; Some(BindOpts::Tcp(tcp_endpoint)) } @@ -145,7 +154,7 @@ impl TcpTunnelBuilder { self } - /// Set the policy for this edge. + /// DEPRECATED: use traffic_policy instead. pub fn policy(&mut self, s: S) -> Result<&mut Self, S::Error> where S: TryInto, @@ -154,6 +163,12 @@ impl TcpTunnelBuilder { Ok(self) } + /// Set policy for this edge. + pub fn traffic_policy(&mut self, policy_str: impl Into) -> &mut Self { + self.options.common_opts.traffic_policy = Some(policy_str.into()); + self + } + pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self { self.options.common_opts.for_forwarding_to(to_url); self diff --git a/ngrok/src/config/tls.rs b/ngrok/src/config/tls.rs index 69fd037..988fb6d 100644 --- a/ngrok/src/config/tls.rs +++ b/ngrok/src/config/tls.rs @@ -92,7 +92,13 @@ impl TunnelConfig for TlsOptions { tls_endpoint.mutual_tls_at_edge = (!self.mutual_tlsca.is_empty()).then_some(self.mutual_tlsca.as_slice().into()); tls_endpoint.tls_termination = tls_termination; - tls_endpoint.policy = self.common_opts.policy.clone().map(From::from); + tls_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() { + self.common_opts.traffic_policy.clone().map(From::from) + } else if self.common_opts.policy.is_some() { + self.common_opts.policy.clone().map(From::from) + } else { + None + }; Some(BindOpts::Tls(tls_endpoint)) } @@ -185,7 +191,7 @@ impl TlsTunnelBuilder { self } - /// Set the policy for this edge. + /// DEPRECATED: use traffic_policy instead. pub fn policy(&mut self, s: S) -> Result<&mut Self, S::Error> where S: TryInto, @@ -194,6 +200,12 @@ impl TlsTunnelBuilder { Ok(self) } + /// Set policy for this edge. + pub fn traffic_policy(&mut self, policy_str: impl Into) -> &mut Self { + self.options.common_opts.traffic_policy = Some(policy_str.into()); + self + } + pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self { self.options.common_opts.for_forwarding_to(to_url); self diff --git a/ngrok/src/internals/proto.rs b/ngrok/src/internals/proto.rs index 7dfa9ba..ce7c5bb 100644 --- a/ngrok/src/internals/proto.rs +++ b/ngrok/src/internals/proto.rs @@ -20,6 +20,7 @@ use serde::{ }, Deserialize, Serialize, + Serializer, }; use thiserror::Error; use tokio::io::{ @@ -651,6 +652,20 @@ impl<'de> Deserialize<'de> for ProxyProto { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +pub enum PolicyWrapper { + #[serde(serialize_with = "serialize_policy")] + Policy(Policy), + String(String), +} + +impl From for PolicyWrapper { + fn from(value: String) -> Self { + PolicyWrapper::String(value) + } +} + #[derive(Serialize, Deserialize, Debug, Clone, Default)] #[serde(rename_all = "PascalCase")] pub struct HttpEndpoint { @@ -681,7 +696,8 @@ pub struct HttpEndpoint { pub websocket_tcp_converter: Option, #[serde(rename = "UserAgentFilter")] pub user_agent_filter: Option, - pub policy: Option, + #[serde(rename = "TrafficPolicy")] + pub traffic_policy: Option, } #[derive(Debug, Clone, Copy, Serialize, Deserialize)] @@ -807,7 +823,8 @@ pub struct TcpEndpoint { pub proxy_proto: ProxyProto, #[serde(rename = "IPRestriction")] pub ip_restriction: Option, - pub policy: Option, + #[serde(rename = "TrafficPolicy")] + pub traffic_policy: Option, } #[derive(Serialize, Deserialize, Debug, Clone, Default)] @@ -825,7 +842,8 @@ pub struct TlsEndpoint { pub tls_termination: Option, #[serde(rename = "IPRestriction")] pub ip_restriction: Option, - pub policy: Option, + #[serde(rename = "TrafficPolicy")] + pub traffic_policy: Option, } #[derive(Serialize, Deserialize, Debug, Clone, Default)] @@ -864,10 +882,49 @@ pub struct Rule { pub struct Action { #[serde(rename = "Type")] pub type_: String, - #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")] + #[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")] pub config: Vec, } +// This function converts a Policy into a valid JSON string. This is used so legacy configurations will still work +// using the new string "TrafficPolicy" field. +fn serialize_policy(v: &Policy, s: S) -> Result { + let abc = match serde_json::to_string(v) { + Ok(t) => t, + Err(_) => { + return Err(serde::ser::Error::custom( + "policy could not be converted to valid json", + )) + } + }; + s.serialize_str(&abc) +} + +// These are helpers to convert base64 strings to full, real json. The serialize helper also ensures that the resulting +// representation isn't a string-escaped string. +mod vec_to_json { + use serde::{ + Deserialize, + Deserializer, + Serialize, + Serializer, + }; + + pub fn serialize(v: &[u8], s: S) -> Result { + let u: serde_json::Value = match serde_json::from_slice(v) { + Ok(k) => k, + Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")), + }; + + u.serialize(s) + } + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let s = String::deserialize(d)?; + Ok(s.into_bytes()) + } +} + // These are helpers to facilitate the Vec <-> base64-encoded bytes // representation that the Go messages use mod base64bytes {