From 4422a3bc32c40120e3b7c104c5d115fb911c8f87 Mon Sep 17 00:00:00 2001 From: "Jorge Alberto Diaz Orozco (Akiel)" Date: Fri, 27 Jan 2023 21:12:58 +0100 Subject: [PATCH] Check for the validity of the ssl fingerprint on every request --- outline_vpn/outline_vpn.py | 80 +++++++++++++++++++++++++++----------- outline_vpn/utils.py | 30 -------------- test_outline_vpn.py | 8 ++-- 3 files changed, 62 insertions(+), 56 deletions(-) delete mode 100644 outline_vpn/utils.py diff --git a/outline_vpn/outline_vpn.py b/outline_vpn/outline_vpn.py index 9e0f0cb..9c9b40e 100644 --- a/outline_vpn/outline_vpn.py +++ b/outline_vpn/outline_vpn.py @@ -5,8 +5,7 @@ from dataclasses import dataclass import requests - -from outline_vpn.utils import check_ssl_fingerprint +from urllib3 import PoolManager @dataclass @@ -29,6 +28,24 @@ class OutlineServerErrorException(Exception): pass +class _FingerprintAdapter(requests.adapters.HTTPAdapter): + """ + This adapter injected into the requests session will check that the + fingerprint for the certificate matches for every request + """ + def __init__(self, fingerprint=None, **kwargs): + self.fingerprint = str(fingerprint) + super(_FingerprintAdapter, self).__init__(**kwargs) + + def init_poolmanager(self, connections, maxsize, block=False): + self.poolmanager = PoolManager( + num_pools=connections, + maxsize=maxsize, + block=block, + assert_fingerprint=self.fingerprint, + ) + + class OutlineVPN: """ An Outline VPN connection @@ -38,13 +55,17 @@ def __init__(self, api_url: str, cert_sha256: str = None): self.api_url = api_url if cert_sha256: - check_ssl_fingerprint(api_url, cert_sha256) + session = requests.Session() + session.mount("https://", _FingerprintAdapter(cert_sha256)) + self.session = session + else: + self.session = requests.Session() def get_keys(self): """Get all keys in the outline server""" - response = requests.get(f"{self.api_url}/access-keys/", verify=False) + response = self.session.get(f"{self.api_url}/access-keys/", verify=False) if response.status_code == 200 and "accessKeys" in response.json(): - response_metrics = requests.get( + response_metrics = self.session.get( f"{self.api_url}/metrics/transfer", verify=False ) if ( @@ -75,7 +96,7 @@ def get_keys(self): def create_key(self, key_name=None) -> OutlineKey: """Create a new key""" - response = requests.post(f"{self.api_url}/access-keys/", verify=False) + response = self.session.post(f"{self.api_url}/access-keys/", verify=False) if response.status_code == 201: key = response.json() outline_key = OutlineKey( @@ -96,7 +117,7 @@ def create_key(self, key_name=None) -> OutlineKey: def delete_key(self, key_id: int) -> bool: """Delete a key""" - response = requests.delete(f"{self.api_url}/access-keys/{key_id}", verify=False) + response = self.session.delete(f"{self.api_url}/access-keys/{key_id}", verify=False) return response.status_code == 204 def rename_key(self, key_id: int, name: str): @@ -105,7 +126,7 @@ def rename_key(self, key_id: int, name: str): "name": (None, name), } - response = requests.put( + response = self.session.put( f"{self.api_url}/access-keys/{key_id}/name", files=files, verify=False ) return response.status_code == 204 @@ -114,14 +135,14 @@ def add_data_limit(self, key_id: int, limit_bytes: int) -> bool: """Set data limit for a key (in bytes)""" data = {"limit": {"bytes": limit_bytes}} - response = requests.put( + response = self.session.put( f"{self.api_url}/access-keys/{key_id}/data-limit", json=data, verify=False ) return response.status_code == 204 def delete_data_limit(self, key_id: int) -> bool: """Removes data limit for a key""" - response = requests.delete( + response = self.session.delete( f"{self.api_url}/access-keys/{key_id}/data-limit", verify=False ) return response.status_code == 204 @@ -135,7 +156,7 @@ def get_transferred_data(self): "3":752221577 } }""" - response = requests.get(f"{self.api_url}/metrics/transfer", verify=False) + response = self.session.get(f"{self.api_url}/metrics/transfer", verify=False) if ( response.status_code >= 400 or "bytesTransferredByUserId" not in response.json() @@ -156,54 +177,69 @@ def get_server_information(self): "hostnameForAccessKeys":"example.com" } """ - response = requests.get(f"{self.api_url}/server", verify=False) + response = self.session.get(f"{self.api_url}/server", verify=False) if response.status_code != 200: - raise OutlineServerErrorException("Unable to get information about the server") + raise OutlineServerErrorException( + "Unable to get information about the server" + ) return response.json() def set_server_name(self, name: str) -> bool: """Renames the server""" data = {"name": name} - response = requests.put(f"{self.api_url}/name", verify=False, json=data) + response = self.session.put(f"{self.api_url}/name", verify=False, json=data) return response.status_code == 204 def set_hostname(self, hostname: str) -> bool: """Changes the hostname for access keys. Must be a valid hostname or IP address.""" data = {"hostname": hostname} - response = requests.put(f"{self.api_url}/server/hostname-for-access-keys", verify=False, json=data) + response = self.session.put( + f"{self.api_url}/server/hostname-for-access-keys", verify=False, json=data + ) return response.status_code == 204 def get_metrics_status(self) -> bool: """Returns whether metrics is being shared""" - response = requests.get(f"{self.api_url}/metrics/enabled", verify=False) + response = self.session.get(f"{self.api_url}/metrics/enabled", verify=False) return response.json().get("metricsEnabled") def set_metrics_status(self, status: bool) -> bool: """Enables or disables sharing of metrics""" data = {"metricsEnabled": status} - response = requests.put(f"{self.api_url}/metrics/enabled", verify=False, json=data) + response = self.session.put( + f"{self.api_url}/metrics/enabled", verify=False, json=data + ) return response.status_code == 204 def set_port_new_for_access_keys(self, port: int) -> bool: """Changes the default port for newly created access keys. This can be a port already used for access keys.""" data = {"port": port} - response = requests.put(f"{self.api_url}/server/port-for-new-access-keys", verify=False, json=data) + response = self.session.put( + f"{self.api_url}/server/port-for-new-access-keys", verify=False, json=data + ) if response.status_code == 400: raise OutlineServerErrorException( - "The requested port wasn't an integer from 1 through 65535, or the request had no port parameter.") + "The requested port wasn't an integer from 1 through 65535, or the request had no port parameter." + ) elif response.status_code == 409: - raise OutlineServerErrorException("The requested port was already in use by another service.") + raise OutlineServerErrorException( + "The requested port was already in use by another service." + ) return response.status_code == 204 def set_data_limit_for_all_keys(self, limit_bytes: int) -> bool: """Sets a data transfer limit for all access keys.""" data = {"limit": {"bytes": limit_bytes}} - response = requests.put(f"{self.api_url}/server/access-key-data-limit", verify=False, json=data) + response = self.session.put( + f"{self.api_url}/server/access-key-data-limit", verify=False, json=data + ) return response.status_code == 204 def delete_data_limit_for_all_keys(self) -> bool: """Removes the access key data limit, lifting data transfer restrictions on all access keys.""" - response = requests.delete(f"{self.api_url}/server/access-key-data-limit", verify=False) + response = self.session.delete( + f"{self.api_url}/server/access-key-data-limit", verify=False + ) return response.status_code == 204 diff --git a/outline_vpn/utils.py b/outline_vpn/utils.py deleted file mode 100644 index b8a615b..0000000 --- a/outline_vpn/utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import ssl -import socket -import hashlib -from urllib.parse import urlparse - - -class OutlineConnectionError(Exception): - pass - - -class OutlineInvalidFingerprintError(Exception): - pass - - -def check_ssl_fingerprint(api_url: str, cert_sha256: str) -> bool: - url = urlparse(api_url) - address = url.hostname - port = url.port - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - wrapped_socket = ssl.wrap_socket(sock) - try: - wrapped_socket.connect((address, port)) - except Exception as e: - raise OutlineConnectionError(f"Connection Error: {e}") - else: - der_cert = wrapped_socket.getpeercert(True) - thumb_sha256 = hashlib.sha256(der_cert).hexdigest() - if thumb_sha256.upper() != cert_sha256.upper(): - raise OutlineInvalidFingerprintError("Invalid fingerprint!") - return True diff --git a/test_outline_vpn.py b/test_outline_vpn.py index f323224..6bab188 100644 --- a/test_outline_vpn.py +++ b/test_outline_vpn.py @@ -6,6 +6,7 @@ import re import pytest +import requests from outline_vpn.outline_vpn import OutlineVPN @@ -19,8 +20,7 @@ def client() -> OutlineVPN: api_url = re.sub("https://.*?:", "https://127.0.0.1:", api_data.get("apiUrl")) client = OutlineVPN( - api_url=api_url, - cert_sha256=api_data.get("certSha256") + api_url=api_url, cert_sha256=api_data.get("certSha256") ) # pylint: disable=W0621 yield client @@ -88,8 +88,8 @@ def test_metrics_status(client: OutlineVPN): def test_data_limit_for_all_keys(client: OutlineVPN): assert client.set_data_limit_for_all_keys(1024 * 1024 * 20) assert client.delete_data_limit_for_all_keys() - - + + def test_get_transferred_data(client: OutlineVPN): """Call the method and assert it responds something""" data = client.get_transferred_data()