Skip to content

Commit

Permalink
block illegal calls in inpod mode (#935)
Browse files Browse the repository at this point in the history
Part of #928

This removes the obsolete call protections for inpod, and adds
equivilents
  • Loading branch information
howardjohn authored Apr 18, 2024
1 parent 9efd781 commit a2b2393
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 7 deletions.
21 changes: 19 additions & 2 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::fmt::Debug;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
Expand Down Expand Up @@ -93,6 +94,7 @@ pub struct Proxy {
outbound: Outbound,
socks5: Socks5,
policy_watcher: PolicyWatcher,
illegal_ports: Arc<HashSet<u16>>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -155,13 +157,19 @@ impl Proxy {
Self::from_inputs(pi, drain).await
}
pub(super) async fn from_inputs(mut pi: ProxyInputs, drain: Watch) -> Result<Self, Error> {
// illegal_ports are internal ports that clients are not authorized to send to
let mut illegal_ports: HashSet<u16> = HashSet::new();
// We setup all the listeners first so we can capture any errors that should block startup
let inbound = Inbound::new(pi.clone(), drain.clone()).await?;
pi.hbone_port = inbound.address().port();
illegal_ports.insert(inbound.address().port());

let inbound_passthrough = InboundPassthrough::new(pi.clone(), drain.clone()).await?;
illegal_ports.insert(inbound_passthrough.address().port());
let outbound = Outbound::new(pi.clone(), drain.clone()).await?;
illegal_ports.insert(outbound.address().port());
let socks5 = Socks5::new(pi.clone(), drain.clone()).await?;
illegal_ports.insert(socks5.address().port());
let policy_watcher = PolicyWatcher::new(pi.state, drain, pi.connection_manager);

Ok(Proxy {
Expand All @@ -170,13 +178,22 @@ impl Proxy {
outbound,
socks5,
policy_watcher,
illegal_ports: Arc::new(illegal_ports),
})
}

pub async fn run(self) {
let tasks = vec![
tokio::spawn(self.inbound_passthrough.run().in_current_span()),
tokio::spawn(self.inbound.run().in_current_span()),
tokio::spawn(
self.inbound_passthrough
.run(self.illegal_ports.clone())
.in_current_span(),
),
tokio::spawn(
self.inbound
.run(self.illegal_ports.clone())
.in_current_span(),
),
tokio::spawn(self.outbound.run().in_current_span()),
tokio::spawn(self.socks5.run().in_current_span()),
tokio::spawn(self.policy_watcher.run().in_current_span()),
Expand Down
22 changes: 20 additions & 2 deletions src/proxy/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
Expand Down Expand Up @@ -83,7 +84,7 @@ impl Inbound {
self.listener.local_addr().expect("local_addr available")
}

pub(super) async fn run(self) {
pub(super) async fn run(self, illegal_ports: Arc<HashSet<u16>>) {
let acceptor = InboundCertProvider {
state: self.pi.state.clone(),
cert_manager: self.pi.cert_manager.clone(),
Expand All @@ -103,6 +104,7 @@ impl Inbound {
let connection_manager = self.pi.connection_manager.clone();
let drain = sub_drain.clone();
let network = self.pi.cfg.network.clone();
let illegal_ports = illegal_ports.clone();
let drain_deadline = self.pi.cfg.self_termination_deadline;
tokio::task::spawn(async move {
let conn = Connection {
Expand All @@ -128,6 +130,7 @@ impl Inbound {
conn.clone(),
enable_original_source.unwrap_or_default(),
req,
illegal_ports.clone(),
connection_manager.clone(),
)
.map(|status| {
Expand Down Expand Up @@ -285,6 +288,7 @@ impl Inbound {
conn: Connection,
enable_original_source: bool,
req: Request<Incoming>,
illegal_ports: Arc<HashSet<u16>>,
connection_manager: ConnectionManager,
) -> StatusCode {
if req.method() != Method::CONNECT {
Expand Down Expand Up @@ -319,7 +323,21 @@ impl Inbound {
);
return StatusCode::BAD_REQUEST;
};

let illegal_call = if pi.cfg.inpod_enabled {
// User sent a request to pod:15006. This would forward to pod:15006 infinitely
illegal_ports.contains(&upstream_addr.port())
} else {
false // TODO: do we need any check here?
};
if illegal_call {
metrics::log_early_deny(
conn.src,
upstream_addr,
Reporter::destination,
Error::SelfCall,
);
return StatusCode::BAD_REQUEST;
}
// Connection has 15008, swap with the real port
let conn = Connection {
dst: upstream_addr,
Expand Down
22 changes: 19 additions & 3 deletions src/proxy/inbound_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
Expand Down Expand Up @@ -62,12 +63,17 @@ impl InboundPassthrough {
})
}

pub(super) async fn run(self) {
pub(super) fn address(&self) -> SocketAddr {
self.listener.local_addr().expect("local_addr available")
}

pub(super) async fn run(self, illegal_ports: Arc<HashSet<u16>>) {
let accept = async move {
loop {
// Asynchronously wait for an inbound socket.
let socket = self.listener.accept().await;
let pi = self.pi.clone();
let illegal_ports = illegal_ports.clone();

let connection_manager = self.pi.connection_manager.clone();
match socket {
Expand All @@ -78,6 +84,7 @@ impl InboundPassthrough {
pi, // pi cloned above; OK to move
socket::to_canonical(remote),
stream,
illegal_ports,
connection_manager,
)
.await
Expand Down Expand Up @@ -110,12 +117,21 @@ impl InboundPassthrough {
pi: ProxyInputs,
source_addr: SocketAddr,
mut inbound_stream: TcpStream,
illegal_ports: Arc<HashSet<u16>>,
connection_manager: ConnectionManager,
) {
let start = Instant::now();
let dest_addr = socket::orig_dst_addr_or_default(&inbound_stream);
// Check if it is a recursive call when proxy mode is Node.
if pi.cfg.proxy_mode == ProxyMode::Shared && Some(dest_addr.ip()) == pi.cfg.local_ip {
// Check if it is an illegal call to ourself, which could trampoline to illegal addresses or
// lead to infinite loops
let illegal_call = if pi.cfg.inpod_enabled {
// User sent a request to pod:15006. This would forward to pod:15006 infinitely
illegal_ports.contains(&dest_addr.port())
} else {
// User sent a request to the ztunnel directly. This isn't allowed
pi.cfg.proxy_mode == ProxyMode::Shared && Some(dest_addr.ip()) == pi.cfg.local_ip
};
if illegal_call {
metrics::log_early_deny(
source_addr,
dest_addr,
Expand Down
3 changes: 3 additions & 0 deletions src/proxy/outbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,11 @@ impl OutboundConnection {
) {
let start = Instant::now();

// Block calls to ztunnel directly, unless we are in "in-pod".
// For in-pod, this isn't an issue and is useful: this allows things like prometheus scraping ztunnel.
if self.pi.cfg.proxy_mode == ProxyMode::Shared
&& Some(dest_addr.ip()) == self.pi.cfg.local_ip
&& !self.pi.cfg.inpod_enabled
{
metrics::log_early_deny(source_addr, dest_addr, Reporter::source, Error::SelfCall);
return;
Expand Down
1 change: 1 addition & 0 deletions src/proxy/socks5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use anyhow::Result;
use byteorder::{BigEndian, ByteOrder};
use drain::Watch;

use std::net::{IpAddr, SocketAddr};

use tokio::io::AsyncReadExt;
Expand Down

0 comments on commit a2b2393

Please sign in to comment.