Skip to content

Commit

Permalink
add support for TrafficPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
TheConcierge committed Jul 8, 2024
1 parent 9300e91 commit e7b369c
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 12 deletions.
4 changes: 3 additions & 1 deletion ngrok/src/config/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ pub(crate) struct CommonOpts {
pub(crate) forwards_proto: Option<String>,
// Whether to disable certificate verification for this tunnel.
verify_upstream_tls: Option<bool>,
// DEPRECATED: use traffic_policy instead.
pub(crate) policy: Option<Policy>,
// Policy that defines rules that should be applied to incoming or outgoing
// connections to the edge.
pub(crate) policy: Option<Policy>,
pub(crate) traffic_policy: Option<String>,
}

impl CommonOpts {
Expand Down
17 changes: 15 additions & 2 deletions ngrok/src/config/http.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
borrow::Borrow,
collections::HashMap,
convert::From,
str::FromStr,
};

Expand Down Expand Up @@ -186,7 +187,13 @@ impl TunnelConfig for HttpOptions {
.websocket_tcp_conversion
.then_some(WebsocketTcpConverter {}),
user_agent_filter: self.user_agent_filter(),
policy: self.common_opts.policy.clone().map(From::from),
traffic_policy: if self.common_opts.traffic_policy.is_some() {
self.common_opts.traffic_policy.clone().map(From::from)
} else if self.common_opts.policy.is_some() {
self.common_opts.policy.clone().map(From::from)
} else {
None
},
..Default::default()
};

Expand Down Expand Up @@ -433,7 +440,7 @@ impl HttpTunnelBuilder {
self
}

/// Set the policy for this edge.
/// DEPRECATED: use traffic_policy instead.
pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
where
S: TryInto<Policy>,
Expand All @@ -442,6 +449,12 @@ impl HttpTunnelBuilder {
Ok(self)
}

/// Set policy for this edge.
pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
self.options.common_opts.traffic_policy = Some(policy_str.into());
self
}

pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
self.options.common_opts.for_forwarding_to(to_url);
if let Some(host) = to_url.host_str().filter(|_| self.options.rewrite_host) {
Expand Down
6 changes: 6 additions & 0 deletions ngrok/src/config/policies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ impl Action {
}
}

impl From<Policy> for proto::PolicyWrapper {
fn from(value: Policy) -> Self {
proto::PolicyWrapper::Policy(value.into())
}
}

// transform into the wire protocol format
impl From<Policy> for proto::Policy {
fn from(o: Policy) -> Self {
Expand Down
21 changes: 18 additions & 3 deletions ngrok/src/config/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::HashMap;
use std::{
collections::HashMap,
convert::From,
};

use url::Url;

Expand Down Expand Up @@ -72,7 +75,13 @@ impl TunnelConfig for TcpOptions {

tcp_endpoint.ip_restriction = self.common_opts.ip_restriction();

tcp_endpoint.policy = self.common_opts.policy.clone().map(From::from);
tcp_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() {
self.common_opts.traffic_policy.clone().map(From::from)
} else if self.common_opts.policy.is_some() {
self.common_opts.policy.clone().map(From::from)
} else {
None
};

Some(BindOpts::Tcp(tcp_endpoint))
}
Expand Down Expand Up @@ -145,7 +154,7 @@ impl TcpTunnelBuilder {
self
}

/// Set the policy for this edge.
/// DEPRECATED: use traffic_policy instead.
pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
where
S: TryInto<Policy>,
Expand All @@ -154,6 +163,12 @@ impl TcpTunnelBuilder {
Ok(self)
}

/// Set policy for this edge.
pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
self.options.common_opts.traffic_policy = Some(policy_str.into());
self
}

pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
self.options.common_opts.for_forwarding_to(to_url);
self
Expand Down
16 changes: 14 additions & 2 deletions ngrok/src/config/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ impl TunnelConfig for TlsOptions {
tls_endpoint.mutual_tls_at_edge =
(!self.mutual_tlsca.is_empty()).then_some(self.mutual_tlsca.as_slice().into());
tls_endpoint.tls_termination = tls_termination;
tls_endpoint.policy = self.common_opts.policy.clone().map(From::from);
tls_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() {
self.common_opts.traffic_policy.clone().map(From::from)
} else if self.common_opts.policy.is_some() {
self.common_opts.policy.clone().map(From::from)
} else {
None
};

Some(BindOpts::Tls(tls_endpoint))
}
Expand Down Expand Up @@ -185,7 +191,7 @@ impl TlsTunnelBuilder {
self
}

/// Set the policy for this edge.
/// DEPRECATED: use traffic_policy instead.
pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
where
S: TryInto<Policy>,
Expand All @@ -194,6 +200,12 @@ impl TlsTunnelBuilder {
Ok(self)
}

/// Set policy for this edge.
pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
self.options.common_opts.traffic_policy = Some(policy_str.into());
self
}

pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
self.options.common_opts.for_forwarding_to(to_url);
self
Expand Down
65 changes: 61 additions & 4 deletions ngrok/src/internals/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use serde::{
},
Deserialize,
Serialize,
Serializer,
};
use thiserror::Error;
use tokio::io::{
Expand Down Expand Up @@ -651,6 +652,20 @@ impl<'de> Deserialize<'de> for ProxyProto {
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PolicyWrapper {
#[serde(serialize_with = "serialize_policy")]
Policy(Policy),
String(String),
}

impl From<String> for PolicyWrapper {
fn from(value: String) -> Self {
PolicyWrapper::String(value)
}
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct HttpEndpoint {
Expand Down Expand Up @@ -681,7 +696,8 @@ pub struct HttpEndpoint {
pub websocket_tcp_converter: Option<WebsocketTcpConverter>,
#[serde(rename = "UserAgentFilter")]
pub user_agent_filter: Option<UserAgentFilter>,
pub policy: Option<Policy>,
#[serde(rename = "TrafficPolicy")]
pub traffic_policy: Option<PolicyWrapper>,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
Expand Down Expand Up @@ -807,7 +823,8 @@ pub struct TcpEndpoint {
pub proxy_proto: ProxyProto,
#[serde(rename = "IPRestriction")]
pub ip_restriction: Option<IpRestriction>,
pub policy: Option<Policy>,
#[serde(rename = "TrafficPolicy")]
pub traffic_policy: Option<PolicyWrapper>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
Expand All @@ -825,7 +842,8 @@ pub struct TlsEndpoint {
pub tls_termination: Option<TlsTermination>,
#[serde(rename = "IPRestriction")]
pub ip_restriction: Option<IpRestriction>,
pub policy: Option<Policy>,
#[serde(rename = "TrafficPolicy")]
pub traffic_policy: Option<PolicyWrapper>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
Expand Down Expand Up @@ -864,10 +882,49 @@ pub struct Rule {
pub struct Action {
#[serde(rename = "Type")]
pub type_: String,
#[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
#[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")]
pub config: Vec<u8>,
}

// This function converts a Policy into a valid JSON string. This is used so legacy configurations will still work
// using the new string "TrafficPolicy" field.
fn serialize_policy<S: Serializer>(v: &Policy, s: S) -> Result<S::Ok, S::Error> {
let abc = match serde_json::to_string(v) {
Ok(t) => t,
Err(_) => {
return Err(serde::ser::Error::custom(
"policy could not be converted to valid json",
))
}
};
s.serialize_str(&abc)
}

// These are helpers to convert base64 strings to full, real json. The serialize helper also ensures that the resulting
// representation isn't a string-escaped string.
mod vec_to_json {
use serde::{
Deserialize,
Deserializer,
Serialize,
Serializer,
};

pub fn serialize<S: Serializer>(v: &[u8], s: S) -> Result<S::Ok, S::Error> {
let u: serde_json::Value = match serde_json::from_slice(v) {
Ok(k) => k,
Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")),
};

u.serialize(s)
}

pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
let s = String::deserialize(d)?;
Ok(s.into_bytes())
}
}

// These are helpers to facilitate the Vec<u8> <-> base64-encoded bytes
// representation that the Go messages use
mod base64bytes {
Expand Down

0 comments on commit e7b369c

Please sign in to comment.