Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNS and TLS proxy #43

Merged
merged 2 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions verified-inference/host/dns_forwarder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import socket
import logging

VSOCK_PORT = 5053

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger()


def get_system_dns():
"""Retrieve the first nameserver from /etc/resolv.conf."""
try:
with open("/etc/resolv.conf", "r") as f:
for line in f:
if line.startswith("nameserver"):
return line.split()[1], 53 # Extract the IP and set port 53
except Exception as e:
logger.error(f"Failed to get system DNS: {e}")

return "8.8.8.8", 53 # Fallback to Google DNS if none found


def forward_dns_to_public(dns_request, dns_ip, dns_port):
"""Forward the DNS request to a public DNS resolver."""
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as dns_sock:
dns_sock.sendto(
dns_request, (dns_ip, dns_port)
) # Send request to the DNS resolver
response, _ = dns_sock.recvfrom(512) # Receive response from the DNS resolver
return response


def vsock_dns_proxy():
"""Listen on a vsock port and forward DNS requests to a public DNS server."""
server_socket = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
server_socket.bind((socket.VMADDR_CID_ANY, VSOCK_PORT))
server_socket.listen(5)
logger.info(f"Listening for DNS requests on vsock port {VSOCK_PORT}...")
dns_ip, dns_port = get_system_dns()
logger.info(f"Using DNS server {dns_ip}:{dns_port}")

while True:
client_socket, _ = server_socket.accept()
logger.info("Received DNS request from enclave.")

try:
# Receive the DNS request
dns_request = client_socket.recv(512)
dns_response = forward_dns_to_public(dns_request, dns_ip, dns_port)

# Send response back to enclave
client_socket.sendall(dns_response)
except Exception as e:
logger.error(f"Error: {e}")
finally:
client_socket.close()


if __name__ == "__main__":
vsock_dns_proxy()
10 changes: 10 additions & 0 deletions verified-inference/host/run_forwarders.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

echo "Starting DNS forwarder..."
nohup python "$SCRIPT_DIR/dns_forwarder.py" > "$SCRIPT_DIR/dns_forwarder.log" 2>&1 &
echo "Started DNS forwarder with PID $!"

echo "Starting traffic forwarder..."
nohup python "$SCRIPT_DIR/traffic_forwarder.py" > "$SCRIPT_DIR/traffic_forwarder.log" 2>&1 &
echo "Started traffic forwarder with PID $!"
19 changes: 19 additions & 0 deletions verified-inference/host/stop_forwarders.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

PIDS=$(pgrep -f "$SCRIPT_DIR/dns_forwarder.py")

if [ -n "$PIDS" ]; then
echo "Killing existing instance(s) of DNS forwarder..."
kill $PIDS
sleep 1
fi

PIDS=$(pgrep -f "$SCRIPT_DIR/traffic_forwarder.py")

if [ -n "$PIDS" ]; then
echo "Killing existing instance(s) of traffic forwarder..."
kill $PIDS
sleep 1
fi

echo "All forwarders have been stopped."
93 changes: 93 additions & 0 deletions verified-inference/host/traffic_forwarder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import socket
import struct
import threading
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger()


class HostProxy:
BUFFER_SIZE = 1024
CID = 3 # The CID of the enclave
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could the enclave CID be a constant? There should also be multiple CIDs for multiple enclaves, isn't it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, but according to the usage, the vsock_server.bind((self.CID, self.PORT)) should be connecting to the enclave IIUC. so here shouldn't we provide the enclave CID?

Copy link
Collaborator Author

@kgrofelnik kgrofelnik Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this line doesn't open a connection to an enclave; instead, it binds the socket to a specific address pair (CID, PORT) so it can wait for incoming connections from enclaves.

PORT = 8001 # The port where this proxy listens for traffic from the enclave

def forward(self, source, destination):
"""Forward data between two sockets."""
try:
while True:
data = source.recv(self.BUFFER_SIZE)
if not data:
break
destination.sendall(data)
except Exception as exc:
logger.error(f"Exception in forward: {exc}")
finally:
source.close()
destination.close()

def handle_vsock_connection(self, vsock_client):
"""
Handles a new connection from the enclave over vsock,
extracts the original destination, and forwards traffic.
"""
try:
# Read first 6 bytes which contain the prepended IP and port
header = vsock_client.recv(6)
if len(header) < 6:
logger.error("Error: Incomplete header received")
vsock_client.close()
return

# Extract original IP and port
original_ip_bin, original_port_bin = header[:4], header[4:6]
original_ip = socket.inet_ntoa(
original_ip_bin
) # Convert binary IP to string
original_port = struct.unpack("!H", original_port_bin)[
0
] # Convert binary port to integer

logger.info(f"Forwarding traffic to {original_ip}:{original_port}")

# Connect to the real destination
destination_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
destination_socket.connect((original_ip, original_port))

# Start forwarding traffic in both directions
outgoing_thread = threading.Thread(
target=self.forward, args=(vsock_client, destination_socket)
)
incoming_thread = threading.Thread(
target=self.forward, args=(destination_socket, vsock_client)
)

outgoing_thread.start()
incoming_thread.start()
except Exception as exc:
logger.error(f"Error handling vsock connection: {exc}")
vsock_client.close()

def start(self):
"""Starts the host-side proxy listening on vsock."""
try:
vsock_server = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
vsock_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
vsock_server.bind((self.CID, self.PORT))
vsock_server.listen(10)

logger.info(f"Host proxy listening on vsock {self.CID}:{self.PORT}")

while True:
client_socket, _ = vsock_server.accept()
threading.Thread(
target=self.handle_vsock_connection, args=(client_socket,)
).start()
except Exception as exc:
logger.error(f"HostProxy exception: {exc}")


if __name__ == "__main__":
proxy = HostProxy()
proxy.start()
Loading