Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for the validity of the ssl fingerprint on every request #11

Merged
merged 1 commit into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions outline_vpn/outline_vpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from dataclasses import dataclass

import requests

from outline_vpn.utils import check_ssl_fingerprint
from urllib3 import PoolManager


@dataclass
Expand All @@ -29,6 +28,24 @@ class OutlineServerErrorException(Exception):
pass


class _FingerprintAdapter(requests.adapters.HTTPAdapter):
"""
This adapter injected into the requests session will check that the
fingerprint for the certificate matches for every request
"""
def __init__(self, fingerprint=None, **kwargs):
self.fingerprint = str(fingerprint)
super(_FingerprintAdapter, self).__init__(**kwargs)

def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = PoolManager(
num_pools=connections,
maxsize=maxsize,
block=block,
assert_fingerprint=self.fingerprint,
)


class OutlineVPN:
"""
An Outline VPN connection
Expand All @@ -38,13 +55,17 @@ def __init__(self, api_url: str, cert_sha256: str = None):
self.api_url = api_url

if cert_sha256:
check_ssl_fingerprint(api_url, cert_sha256)
session = requests.Session()
session.mount("https://", _FingerprintAdapter(cert_sha256))
self.session = session
else:
self.session = requests.Session()

def get_keys(self):
"""Get all keys in the outline server"""
response = requests.get(f"{self.api_url}/access-keys/", verify=False)
response = self.session.get(f"{self.api_url}/access-keys/", verify=False)
if response.status_code == 200 and "accessKeys" in response.json():
response_metrics = requests.get(
response_metrics = self.session.get(
f"{self.api_url}/metrics/transfer", verify=False
)
if (
Expand Down Expand Up @@ -75,7 +96,7 @@ def get_keys(self):

def create_key(self, key_name=None) -> OutlineKey:
"""Create a new key"""
response = requests.post(f"{self.api_url}/access-keys/", verify=False)
response = self.session.post(f"{self.api_url}/access-keys/", verify=False)
if response.status_code == 201:
key = response.json()
outline_key = OutlineKey(
Expand All @@ -96,7 +117,7 @@ def create_key(self, key_name=None) -> OutlineKey:

def delete_key(self, key_id: int) -> bool:
"""Delete a key"""
response = requests.delete(f"{self.api_url}/access-keys/{key_id}", verify=False)
response = self.session.delete(f"{self.api_url}/access-keys/{key_id}", verify=False)
return response.status_code == 204

def rename_key(self, key_id: int, name: str):
Expand All @@ -105,7 +126,7 @@ def rename_key(self, key_id: int, name: str):
"name": (None, name),
}

response = requests.put(
response = self.session.put(
f"{self.api_url}/access-keys/{key_id}/name", files=files, verify=False
)
return response.status_code == 204
Expand All @@ -114,14 +135,14 @@ def add_data_limit(self, key_id: int, limit_bytes: int) -> bool:
"""Set data limit for a key (in bytes)"""
data = {"limit": {"bytes": limit_bytes}}

response = requests.put(
response = self.session.put(
f"{self.api_url}/access-keys/{key_id}/data-limit", json=data, verify=False
)
return response.status_code == 204

def delete_data_limit(self, key_id: int) -> bool:
"""Removes data limit for a key"""
response = requests.delete(
response = self.session.delete(
f"{self.api_url}/access-keys/{key_id}/data-limit", verify=False
)
return response.status_code == 204
Expand All @@ -135,7 +156,7 @@ def get_transferred_data(self):
"3":752221577
}
}"""
response = requests.get(f"{self.api_url}/metrics/transfer", verify=False)
response = self.session.get(f"{self.api_url}/metrics/transfer", verify=False)
if (
response.status_code >= 400
or "bytesTransferredByUserId" not in response.json()
Expand All @@ -156,54 +177,69 @@ def get_server_information(self):
"hostnameForAccessKeys":"example.com"
}
"""
response = requests.get(f"{self.api_url}/server", verify=False)
response = self.session.get(f"{self.api_url}/server", verify=False)
if response.status_code != 200:
raise OutlineServerErrorException("Unable to get information about the server")
raise OutlineServerErrorException(
"Unable to get information about the server"
)
return response.json()

def set_server_name(self, name: str) -> bool:
"""Renames the server"""
data = {"name": name}
response = requests.put(f"{self.api_url}/name", verify=False, json=data)
response = self.session.put(f"{self.api_url}/name", verify=False, json=data)
return response.status_code == 204

def set_hostname(self, hostname: str) -> bool:
"""Changes the hostname for access keys.
Must be a valid hostname or IP address."""
data = {"hostname": hostname}
response = requests.put(f"{self.api_url}/server/hostname-for-access-keys", verify=False, json=data)
response = self.session.put(
f"{self.api_url}/server/hostname-for-access-keys", verify=False, json=data
Fixed Show fixed Hide fixed
)
return response.status_code == 204

def get_metrics_status(self) -> bool:
"""Returns whether metrics is being shared"""
response = requests.get(f"{self.api_url}/metrics/enabled", verify=False)
response = self.session.get(f"{self.api_url}/metrics/enabled", verify=False)
return response.json().get("metricsEnabled")

def set_metrics_status(self, status: bool) -> bool:
"""Enables or disables sharing of metrics"""
data = {"metricsEnabled": status}
response = requests.put(f"{self.api_url}/metrics/enabled", verify=False, json=data)
response = self.session.put(
f"{self.api_url}/metrics/enabled", verify=False, json=data
Fixed Show fixed Hide fixed
)
return response.status_code == 204

def set_port_new_for_access_keys(self, port: int) -> bool:
"""Changes the default port for newly created access keys.
This can be a port already used for access keys."""
data = {"port": port}
response = requests.put(f"{self.api_url}/server/port-for-new-access-keys", verify=False, json=data)
response = self.session.put(
f"{self.api_url}/server/port-for-new-access-keys", verify=False, json=data
Fixed Show fixed Hide fixed
)
if response.status_code == 400:
raise OutlineServerErrorException(
"The requested port wasn't an integer from 1 through 65535, or the request had no port parameter.")
"The requested port wasn't an integer from 1 through 65535, or the request had no port parameter."
)
elif response.status_code == 409:
raise OutlineServerErrorException("The requested port was already in use by another service.")
raise OutlineServerErrorException(
"The requested port was already in use by another service."
)
return response.status_code == 204

def set_data_limit_for_all_keys(self, limit_bytes: int) -> bool:
"""Sets a data transfer limit for all access keys."""
data = {"limit": {"bytes": limit_bytes}}
response = requests.put(f"{self.api_url}/server/access-key-data-limit", verify=False, json=data)
response = self.session.put(
f"{self.api_url}/server/access-key-data-limit", verify=False, json=data
Fixed Show fixed Hide fixed
)
return response.status_code == 204

def delete_data_limit_for_all_keys(self) -> bool:
"""Removes the access key data limit, lifting data transfer restrictions on all access keys."""
response = requests.delete(f"{self.api_url}/server/access-key-data-limit", verify=False)
response = self.session.delete(
f"{self.api_url}/server/access-key-data-limit", verify=False
Fixed Show fixed Hide fixed
)
return response.status_code == 204
30 changes: 0 additions & 30 deletions outline_vpn/utils.py

This file was deleted.

8 changes: 4 additions & 4 deletions test_outline_vpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re

import pytest
import requests

from outline_vpn.outline_vpn import OutlineVPN

Expand All @@ -19,8 +20,7 @@ def client() -> OutlineVPN:
api_url = re.sub("https://.*?:", "https://127.0.0.1:", api_data.get("apiUrl"))

client = OutlineVPN(
api_url=api_url,
cert_sha256=api_data.get("certSha256")
api_url=api_url, cert_sha256=api_data.get("certSha256")
) # pylint: disable=W0621
yield client

Expand Down Expand Up @@ -88,8 +88,8 @@ def test_metrics_status(client: OutlineVPN):
def test_data_limit_for_all_keys(client: OutlineVPN):
assert client.set_data_limit_for_all_keys(1024 * 1024 * 20)
assert client.delete_data_limit_for_all_keys()


def test_get_transferred_data(client: OutlineVPN):
"""Call the method and assert it responds something"""
data = client.get_transferred_data()
Expand Down