-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: IP Address range block support
- Updated Readme.md. - Update package version to 1.2.0. - Added network based permission example. - Refactored package structure by separating concerns. - Refactored handling host check using host type instead of iterating through all validations.
- Loading branch information
Showing
8 changed files
with
231 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from flask import Flask, request, jsonify | ||
|
||
from flask_allowed_hosts import AllowedHosts | ||
|
||
ALLOWED_HOSTS = ["93.184.215.14", "api.example.com"] | ||
|
||
|
||
# Returns a json response if the request IP/hostname is not in the allowed hosts | ||
def custom_on_denied(): | ||
error = {"error": "Oops! looks like you are not allowed to access this page!"} | ||
return jsonify(error), 403 | ||
|
||
|
||
app = Flask(__name__) | ||
allowed_hosts = AllowedHosts(app, allowed_hosts=ALLOWED_HOSTS, on_denied=custom_on_denied) | ||
|
||
|
||
# This endpoint only allows incoming requests from "93.184.215.14" and "api.example.com" | ||
@app.route("/api/greet", methods=["GET"]) | ||
@allowed_hosts.limit() | ||
def greet_endpoint(): | ||
name = request.args.get("name", "Friend") | ||
data = {"message": f"Hello There {name}!"} | ||
return jsonify(data), 200 | ||
|
||
|
||
# This endpoint allows all incoming requests | ||
@app.route("/api/public", methods=["GET"]) | ||
@allowed_hosts.limit(allowed_hosts=["0.0.0.0/0"]) | ||
def public_endpoint(): | ||
data = {"message": f"this is a public endpoint by override"} | ||
return jsonify(data), 200 | ||
|
||
|
||
# This endpoint only allows incoming requests from "127.0.0.1" to "127.0.0.255" | ||
@app.route("/api/network", methods=["GET"]) | ||
@allowed_hosts.limit(allowed_hosts=["127.0.0.0/24"]) | ||
def override_endpoint(): | ||
data = {"message": f"this endpoint is restricted to 127.0.0.0/24"} | ||
return jsonify(data), 200 | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(host='0.0.0.0', port=5000, debug=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,62 @@ | ||
import os | ||
import socket | ||
|
||
from typing import List, Union | ||
|
||
# Flask modules | ||
from flask import request | ||
|
||
DEBUG = os.environ.get("ALLOWED_HOSTS_DEBUG", "False") == "True" | ||
LOCAL_HOST_VARIANTS = ('localhost', '127.0.0.1', '::1') | ||
|
||
# Python modules | ||
import re | ||
import socket | ||
import ipaddress | ||
from typing import List | ||
|
||
def debug_log(message: str) -> None: | ||
if DEBUG: | ||
print(f"Flask Allowed Hosts -> {message}") | ||
# Local modules | ||
from flask_allowed_hosts.logger import AllowedHostsLogger | ||
|
||
|
||
def get_remote_address() -> str: | ||
return request.remote_addr or "127.0.0.1" | ||
|
||
|
||
def get_hostname_ips(host: str) -> List[str]: | ||
def get_host_ips(host: str) -> List[str]: | ||
try: | ||
host = socket.gethostbyname_ex(host) | ||
debug_log(f"Host: {host}") | ||
AllowedHostsLogger.info(f"Host: {host}") | ||
host_ips = host[2] | ||
debug_log(f"Host IPs: {host_ips}") | ||
AllowedHostsLogger.info(f"Host IPs: {host_ips}") | ||
return host_ips | ||
except socket.gaierror: | ||
debug_log(f"get_hostname_ips error: {host}") | ||
AllowedHostsLogger.error(f"get_host_ips error: {host}") | ||
return [] | ||
|
||
|
||
def is_real_hostname(host: str, request_ip: str) -> bool: | ||
host_ips = get_hostname_ips(host) | ||
return request_ip in host_ips | ||
def is_local_host(host: str) -> bool: | ||
try: | ||
host_ip = socket.gethostbyname(host) | ||
return ipaddress.ip_address(host_ip).is_loopback | ||
except socket.gaierror: | ||
return False | ||
|
||
|
||
def is_local_connection_allowed(host: str, client_ip: str) -> bool: | ||
return host in LOCAL_HOST_VARIANTS and client_ip in ('127.0.0.1', '::1', '::ffff:127.0.0.1') | ||
def is_valid_cidr_network(address: str, strict: bool = False) -> bool: | ||
# Regex pattern match CIDR networks | ||
pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\/(?:[0-9]|[1-2][0-9]|3[0-2]))' | ||
regex = re.compile(pattern) | ||
|
||
if not regex.fullmatch(address): | ||
return False | ||
|
||
def is_valid_host(request_ip: str, allowed_hosts: Union[List[str], str]) -> bool: | ||
if not allowed_hosts or allowed_hosts in ("*", ["*"]): | ||
debug_log("All hosts are allowed, request was permitted.") | ||
try: | ||
ipaddress.ip_network(address, strict=strict) | ||
return True | ||
except ValueError: | ||
AllowedHostsLogger.error(f"is_valid_cidr_network error: {address}") | ||
return False | ||
|
||
if isinstance(allowed_hosts, str): | ||
allowed_hosts = [allowed_hosts] | ||
|
||
debug_log(f"Request IP: {request_ip}") | ||
def get_host_type(host: str) -> str: | ||
if is_local_host(host): | ||
return "localhost" | ||
|
||
for host in allowed_hosts: | ||
if is_local_connection_allowed(host, request_ip): | ||
debug_log("Localhost connection permitted") | ||
return True | ||
elif is_real_hostname(host, request_ip): | ||
debug_log("Valid Host, request was permitted.") | ||
return True | ||
if is_valid_cidr_network(host): | ||
return "network" | ||
|
||
debug_log("Invalid Host, request was not permitted") | ||
return False | ||
return "host" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
|
||
|
||
class AllowedHostsLogger: | ||
DEBUG = os.environ.get("ALLOWED_HOSTS_DEBUG", "False") == "True" | ||
|
||
@classmethod | ||
def _print(cls, message: str, emoji: str = None) -> None: | ||
if not cls.DEBUG: | ||
return | ||
|
||
debug_message = "Flask Allowed Hosts -> " + message | ||
|
||
if emoji is not None: | ||
debug_message = f"{emoji}{debug_message}" | ||
|
||
print(debug_message) | ||
|
||
@classmethod | ||
def info(cls, message: str) -> None: | ||
cls._print(message, "🔍") | ||
|
||
@classmethod | ||
def error(cls, message: str) -> None: | ||
cls._print(message, "❌") | ||
|
||
@classmethod | ||
def warning(cls, message: str) -> None: | ||
cls._print(message, "⚠️") | ||
|
||
@classmethod | ||
def success(cls, message: str) -> None: | ||
cls._print(message, "✅") | ||
|
||
@classmethod | ||
def custom(cls, message: str, emoji: str) -> None: | ||
cls._print(message, emoji) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Python modules | ||
import re | ||
import socket | ||
import ipaddress | ||
from typing import List, Union | ||
|
||
# Local modules | ||
from flask_allowed_hosts.helpers import get_host_ips, get_host_type | ||
|
||
from flask_allowed_hosts.logger import AllowedHostsLogger | ||
|
||
|
||
class PermissionManager: | ||
|
||
@staticmethod | ||
def _is_local_access_allowed(host: str, request_ip: str) -> bool: | ||
host_ip = socket.gethostbyname(host) | ||
if not ipaddress.ip_address(host_ip).is_loopback: | ||
return False | ||
|
||
return ipaddress.ip_address(request_ip).is_loopback | ||
|
||
@staticmethod | ||
def _is_network_access_allowed(address: str, request_ip: str, strict: bool = False) -> bool: | ||
try: | ||
ip_address = ipaddress.ip_address(request_ip) | ||
host_network = ipaddress.ip_network(address, strict=strict) | ||
network_range = host_network.num_addresses - 1 | ||
AllowedHostsLogger.info(f"IP: {ip_address} - Network: {host_network} (range: {network_range})") | ||
|
||
return ip_address in host_network | ||
except Exception as e: | ||
AllowedHostsLogger.error(f"is_network_access_allowed error: {str(e)}") | ||
return False | ||
|
||
@staticmethod | ||
def _is_host_access_allowed(host: str, request_ip: str) -> bool: | ||
host_ips = get_host_ips(host) | ||
|
||
return request_ip in host_ips | ||
|
||
@classmethod | ||
def is_request_allowed(cls, request_ip: str, allowed_hosts: Union[List[str], str]) -> bool: | ||
if not allowed_hosts or allowed_hosts in ("*", ["*"]): | ||
AllowedHostsLogger.success("All hosts are allowed, request was permitted.") | ||
return True | ||
|
||
if isinstance(allowed_hosts, str): | ||
allowed_hosts = [allowed_hosts] | ||
|
||
AllowedHostsLogger.info(f"Request IP: {request_ip}") | ||
|
||
for host in allowed_hosts: | ||
host_type = get_host_type(host) | ||
|
||
AllowedHostsLogger.info(f"Host Type: {host_type}") | ||
|
||
if host_type == "localhost" and cls._is_local_access_allowed(host, request_ip): | ||
AllowedHostsLogger.success(f"Local Host request was permitted.") | ||
return True | ||
elif host_type == "network" and cls._is_network_access_allowed(host, request_ip): | ||
AllowedHostsLogger.success(f"Network Host request was permitted.") | ||
return True | ||
elif cls._is_host_access_allowed(host, request_ip): | ||
AllowedHostsLogger.success(f"Host request was permitted.") | ||
return True | ||
|
||
AllowedHostsLogger.custom("Invalid Host, request was not permitted", "🚫") | ||
return False |
Oops, something went wrong.