From 5753ec14874d786d110a9b9a4a40242023c8c3fa Mon Sep 17 00:00:00 2001 From: jptomoya <4786564+jptomoya@users.noreply.github.com> Date: Sun, 10 Nov 2024 13:02:16 +0900 Subject: [PATCH] Fix SNI handling in Socket --- ptrlib/connection/sock.py | 4 +++- tests/connection/test_sock.py | 37 ++++++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/ptrlib/connection/sock.py b/ptrlib/connection/sock.py index 5386587..4ee92f3 100644 --- a/ptrlib/connection/sock.py +++ b/ptrlib/connection/sock.py @@ -68,8 +68,10 @@ def __init__(self, self.context = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT) self.context.check_hostname = False self.context.verify_mode = _ssl.CERT_NONE - if sni is True: + if not sni: self._sock = self.context.wrap_socket(self._sock) + elif sni is True: + self._sock = self.context.wrap_socket(self._sock, server_hostname=host) else: self._sock = self.context.wrap_socket(self._sock, server_hostname=sni) diff --git a/tests/connection/test_sock.py b/tests/connection/test_sock.py index e5b35ea..22be4b2 100644 --- a/tests/connection/test_sock.py +++ b/tests/connection/test_sock.py @@ -1,3 +1,4 @@ +import json import unittest from socket import gethostbyname from ptrlib import Socket @@ -45,21 +46,35 @@ def test_reset(self): self.assertEqual(b"200 OK" in cm.exception.args[1], True) def test_tls(self): - host = "www.example.com" + host = "check-tls.akamaized.net" + path = "/v1/tlssni.json" - # connect with sni - ip_addr = gethostbyname(host) - sock = Socket(ip_addr, 443, ssl=True, sni=host) - sock.sendline(b'GET / HTTP/1.1\r') - sock.send(b'Host: www.example.com\r\n') + # connect with SNI enabled + sock = Socket(host, 443, ssl=True) + sock.sendline(f'GET {path} HTTP/1.1'.encode() + b'\r') + sock.send(f'Host: {host}'.encode() + b'\r\n') sock.send(b'Connection: close\r\n\r\n') self.assertTrue(int(sock.recvlineafter('Content-Length: ')) > 0) sock.close() - # connect without sni - sock = Socket(host, 443, ssl=True) - sock.sendline(b'GET / HTTP/1.1\r') - sock.send(b'Host: www.example.com\r\n') + # connect with a specific SNI value + sock = Socket(host, 443, ssl=True, sni="example.com") + sock.sendline(f'GET {path} HTTP/1.1'.encode() + b'\r') + sock.send(f'Host: {host}'.encode() + b'\r\n') sock.send(b'Connection: close\r\n\r\n') - self.assertTrue(int(sock.recvlineafter('Content-Length: ')) > 0) + self.assertTrue((contentlength := int(sock.recvlineafter('Content-Length: '))) > 0) + sock.recvuntil(b'\r\n\r\n') + content = json.loads(sock.recvonce(contentlength)) + sock.close() + self.assertEqual(content['tls_sni_value'], "example.com") + + # connect with SNI disabled + sock = Socket(host, 443, ssl=True, sni=False) + sock.sendline(f'GET {path} HTTP/1.1'.encode() + b'\r') + sock.send(f'Host: {host}'.encode() + b'\r\n') + sock.send(b'Connection: close\r\n\r\n') + self.assertTrue((contentlength := int(sock.recvlineafter('Content-Length: '))) > 0) + sock.recvuntil(b'\r\n\r\n') + content = json.loads(sock.recvonce(contentlength)) sock.close() + self.assertEqual(content['tls_sni_status'], "missing")