Skip to content

Commit

Permalink
feat: IP Address range block support
Browse files Browse the repository at this point in the history
- Updated Readme.md.
- Update package version to 1.2.0.
- Added network based permission example.
- Refactored package structure by separating concerns.
- Refactored handling host check using host type instead of iterating through all validations.
  • Loading branch information
riad-azz committed Oct 10, 2024
1 parent 243a616 commit 1e2a832
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 71 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Flask Allowed Hosts

This extension provides a way to restrict access to your Flask application based on the incoming request's hostname or
IP address.
IP address or IP address range (network).

## Features

- Restrict access by hostname/IP address.
- Per-route configuration options.
- Customize denied access behavior.
- Two usage options: class-based or decorator-based.
- Restrict access by hostname, IP address or IP address range (network).

## Installation

Expand Down
44 changes: 44 additions & 0 deletions examples/network_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from flask import Flask, request, jsonify

from flask_allowed_hosts import AllowedHosts

ALLOWED_HOSTS = ["93.184.215.14", "api.example.com"]


# Returns a json response if the request IP/hostname is not in the allowed hosts
def custom_on_denied():
error = {"error": "Oops! looks like you are not allowed to access this page!"}
return jsonify(error), 403


app = Flask(__name__)
allowed_hosts = AllowedHosts(app, allowed_hosts=ALLOWED_HOSTS, on_denied=custom_on_denied)


# This endpoint only allows incoming requests from "93.184.215.14" and "api.example.com"
@app.route("/api/greet", methods=["GET"])
@allowed_hosts.limit()
def greet_endpoint():
name = request.args.get("name", "Friend")
data = {"message": f"Hello There {name}!"}
return jsonify(data), 200


# This endpoint allows all incoming requests
@app.route("/api/public", methods=["GET"])
@allowed_hosts.limit(allowed_hosts=["0.0.0.0/0"])
def public_endpoint():
data = {"message": f"this is a public endpoint by override"}
return jsonify(data), 200


# This endpoint only allows incoming requests from "127.0.0.1" to "127.0.0.255"
@app.route("/api/network", methods=["GET"])
@allowed_hosts.limit(allowed_hosts=["127.0.0.0/24"])
def override_endpoint():
data = {"message": f"this endpoint is restricted to 127.0.0.0/24"}
return jsonify(data), 200


if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
36 changes: 22 additions & 14 deletions flask_allowed_hosts/allowed_hosts.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,53 @@
# Flask modules
from flask import abort, Flask

# Other modules
# Python modules
from functools import wraps
from typing import List, Callable, Union

# Local modules
from flask_allowed_hosts.validators import validate_limit_parameters
from flask_allowed_hosts.helpers import is_valid_host, get_remote_address
from flask_allowed_hosts.helpers import get_remote_address
from flask_allowed_hosts.validators import ConfigValidator
from flask_allowed_hosts.permission_manager import PermissionManager


class AllowedHosts:
def __init__(self, app=None, allowed_hosts: Union[List[str], str] = None, on_denied: Callable = None):
validate_limit_parameters(allowed_hosts, on_denied)

def __init__(self, app=None, allowed_hosts: Union[List[str], str] = None, on_denied: Callable = None):
# Configurations
self.app = app
self.on_denied = on_denied
self.allowed_hosts = allowed_hosts
self.on_denied = ConfigValidator.validate_on_denied(on_denied)
self.allowed_hosts = ConfigValidator.validate_allowed_hosts(allowed_hosts)

# Initialization
if app is not None:
self.init_app(app)

def init_app(self, app: Flask):
if self.allowed_hosts is None:
self.allowed_hosts = app.config.get('ALLOWED_HOSTS', ["*"])
allowed_hosts_config = app.config.get('ALLOWED_HOSTS', ["*"])
self.allowed_hosts = ConfigValidator.validate_allowed_hosts(allowed_hosts_config)

if self.on_denied is None:
self.on_denied = app.config.get('ALLOWED_HOSTS_ON_DENIED', None)
on_denied_config = app.config.get('ALLOWED_HOSTS_ON_DENIED', None)
self.on_denied = ConfigValidator.validate_on_denied(on_denied_config)

def limit(self, allowed_hosts: Union[List[str], str] = None, on_denied: Callable = None):
validate_limit_parameters(allowed_hosts, on_denied)

if allowed_hosts is None:
allowed_hosts = self.allowed_hosts
else:
allowed_hosts = ConfigValidator.validate_allowed_hosts(allowed_hosts)

if on_denied is None:
on_denied = self.on_denied
else:
on_denied = ConfigValidator.validate_on_denied(on_denied)

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
request_ip = get_remote_address()
if is_valid_host(request_ip, allowed_hosts):
if PermissionManager.is_request_allowed(request_ip, allowed_hosts):
return func(*args, **kwargs)

if callable(on_denied):
Expand All @@ -55,14 +61,16 @@ def wrapper(*args, **kwargs):


# For backward compatibility

def limit_hosts(allowed_hosts: Union[List[str], str] = None, on_denied: Callable = None):
validate_limit_parameters(allowed_hosts, on_denied)
on_denied = ConfigValidator.validate_on_denied(on_denied)
allowed_hosts = ConfigValidator.validate_allowed_hosts(allowed_hosts)

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
request_ip = get_remote_address()
if is_valid_host(request_ip, allowed_hosts):
if PermissionManager.is_request_allowed(request_ip, allowed_hosts):
return func(*args, **kwargs)

if callable(on_denied):
Expand Down
69 changes: 35 additions & 34 deletions flask_allowed_hosts/helpers.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,62 @@
import os
import socket

from typing import List, Union

# Flask modules
from flask import request

DEBUG = os.environ.get("ALLOWED_HOSTS_DEBUG", "False") == "True"
LOCAL_HOST_VARIANTS = ('localhost', '127.0.0.1', '::1')

# Python modules
import re
import socket
import ipaddress
from typing import List

def debug_log(message: str) -> None:
if DEBUG:
print(f"Flask Allowed Hosts -> {message}")
# Local modules
from flask_allowed_hosts.logger import AllowedHostsLogger


def get_remote_address() -> str:
return request.remote_addr or "127.0.0.1"


def get_hostname_ips(host: str) -> List[str]:
def get_host_ips(host: str) -> List[str]:
try:
host = socket.gethostbyname_ex(host)
debug_log(f"Host: {host}")
AllowedHostsLogger.info(f"Host: {host}")
host_ips = host[2]
debug_log(f"Host IPs: {host_ips}")
AllowedHostsLogger.info(f"Host IPs: {host_ips}")
return host_ips
except socket.gaierror:
debug_log(f"get_hostname_ips error: {host}")
AllowedHostsLogger.error(f"get_host_ips error: {host}")
return []


def is_real_hostname(host: str, request_ip: str) -> bool:
host_ips = get_hostname_ips(host)
return request_ip in host_ips
def is_local_host(host: str) -> bool:
try:
host_ip = socket.gethostbyname(host)
return ipaddress.ip_address(host_ip).is_loopback
except socket.gaierror:
return False


def is_local_connection_allowed(host: str, client_ip: str) -> bool:
return host in LOCAL_HOST_VARIANTS and client_ip in ('127.0.0.1', '::1', '::ffff:127.0.0.1')
def is_valid_cidr_network(address: str, strict: bool = False) -> bool:
# Regex pattern match CIDR networks
pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\/(?:[0-9]|[1-2][0-9]|3[0-2]))'
regex = re.compile(pattern)

if not regex.fullmatch(address):
return False

def is_valid_host(request_ip: str, allowed_hosts: Union[List[str], str]) -> bool:
if not allowed_hosts or allowed_hosts in ("*", ["*"]):
debug_log("All hosts are allowed, request was permitted.")
try:
ipaddress.ip_network(address, strict=strict)
return True
except ValueError:
AllowedHostsLogger.error(f"is_valid_cidr_network error: {address}")
return False

if isinstance(allowed_hosts, str):
allowed_hosts = [allowed_hosts]

debug_log(f"Request IP: {request_ip}")
def get_host_type(host: str) -> str:
if is_local_host(host):
return "localhost"

for host in allowed_hosts:
if is_local_connection_allowed(host, request_ip):
debug_log("Localhost connection permitted")
return True
elif is_real_hostname(host, request_ip):
debug_log("Valid Host, request was permitted.")
return True
if is_valid_cidr_network(host):
return "network"

debug_log("Invalid Host, request was not permitted")
return False
return "host"
37 changes: 37 additions & 0 deletions flask_allowed_hosts/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os


class AllowedHostsLogger:
DEBUG = os.environ.get("ALLOWED_HOSTS_DEBUG", "False") == "True"

@classmethod
def _print(cls, message: str, emoji: str = None) -> None:
if not cls.DEBUG:
return

debug_message = "Flask Allowed Hosts -> " + message

if emoji is not None:
debug_message = f"{emoji}{debug_message}"

print(debug_message)

@classmethod
def info(cls, message: str) -> None:
cls._print(message, "🔍")

@classmethod
def error(cls, message: str) -> None:
cls._print(message, "❌")

@classmethod
def warning(cls, message: str) -> None:
cls._print(message, "⚠️")

@classmethod
def success(cls, message: str) -> None:
cls._print(message, "✅")

@classmethod
def custom(cls, message: str, emoji: str) -> None:
cls._print(message, emoji)
69 changes: 69 additions & 0 deletions flask_allowed_hosts/permission_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Python modules
import re
import socket
import ipaddress
from typing import List, Union

# Local modules
from flask_allowed_hosts.helpers import get_host_ips, get_host_type

from flask_allowed_hosts.logger import AllowedHostsLogger


class PermissionManager:

@staticmethod
def _is_local_access_allowed(host: str, request_ip: str) -> bool:
host_ip = socket.gethostbyname(host)
if not ipaddress.ip_address(host_ip).is_loopback:
return False

return ipaddress.ip_address(request_ip).is_loopback

@staticmethod
def _is_network_access_allowed(address: str, request_ip: str, strict: bool = False) -> bool:
try:
ip_address = ipaddress.ip_address(request_ip)
host_network = ipaddress.ip_network(address, strict=strict)
network_range = host_network.num_addresses - 1
AllowedHostsLogger.info(f"IP: {ip_address} - Network: {host_network} (range: {network_range})")

return ip_address in host_network
except Exception as e:
AllowedHostsLogger.error(f"is_network_access_allowed error: {str(e)}")
return False

@staticmethod
def _is_host_access_allowed(host: str, request_ip: str) -> bool:
host_ips = get_host_ips(host)

return request_ip in host_ips

@classmethod
def is_request_allowed(cls, request_ip: str, allowed_hosts: Union[List[str], str]) -> bool:
if not allowed_hosts or allowed_hosts in ("*", ["*"]):
AllowedHostsLogger.success("All hosts are allowed, request was permitted.")
return True

if isinstance(allowed_hosts, str):
allowed_hosts = [allowed_hosts]

AllowedHostsLogger.info(f"Request IP: {request_ip}")

for host in allowed_hosts:
host_type = get_host_type(host)

AllowedHostsLogger.info(f"Host Type: {host_type}")

if host_type == "localhost" and cls._is_local_access_allowed(host, request_ip):
AllowedHostsLogger.success(f"Local Host request was permitted.")
return True
elif host_type == "network" and cls._is_network_access_allowed(host, request_ip):
AllowedHostsLogger.success(f"Network Host request was permitted.")
return True
elif cls._is_host_access_allowed(host, request_ip):
AllowedHostsLogger.success(f"Host request was permitted.")
return True

AllowedHostsLogger.custom("Invalid Host, request was not permitted", "🚫")
return False
Loading

0 comments on commit 1e2a832

Please sign in to comment.