-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmain.py
178 lines (115 loc) · 5.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import logging
import pathlib
from typing import List
import requests
import cloudflare
import configparser
import pandas as pd
import os
class App:
    """Sync the blocklists named in ``config.ini`` into Cloudflare Gateway.

    The lists configured under the ``[Lists]`` section are downloaded into
    ``tmp/``, parsed into domains, de-duplicated, uploaded through the
    ``cloudflare`` helper module in chunks of 1000 and referenced from a
    single gateway policy.
    """

    def __init__(self):
        # Prefix used both to find and to name the lists/policy in Cloudflare.
        self.name_prefix = "[CFPihole]"
        self.logger = logging.getLogger("main")
        # Domains that must never be blocked.
        self.whitelist = self.loadWhitelist()

    def loadWhitelist(self):
        """Return the non-empty, stripped lines of ``whitelist.txt``.

        Raises:
            FileNotFoundError: if ``whitelist.txt`` does not exist in the
                current working directory.
        """
        with open("whitelist.txt", "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [entry for entry in stripped if entry]

    def run(self):
        """Download every configured list and sync the result to Cloudflare.

        When the number of unique domains equals the total ``count`` of the
        lists already present in Cloudflare, nothing is changed. Otherwise
        the existing policy and lists are deleted and recreated.

        Raises:
            Exception: if more than one matching gateway policy exists after
                the lists have been recreated.
        """
        config = configparser.ConfigParser()
        config.read('config.ini')

        # The downloads are written to ./tmp, so make sure it exists.
        os.makedirs("./tmp", exist_ok=True)

        all_domains = []
        for source_name in config["Lists"]:
            print("Setting list " + source_name)
            self.download_file(config["Lists"][source_name], source_name)
            all_domains.extend(self.convert_to_domain_list(source_name))

        unique_domains = pd.unique(all_domains)

        # Lists that already exist in Cloudflare under our prefix.
        cf_lists = cloudflare.get_lists(self.name_prefix)
        self.logger.info(f"Number of lists in Cloudflare: {len(cf_lists)}")

        # Size comparison is the (cheap) change detector: equal totals are
        # treated as "nothing changed".
        remote_total = sum(cf_list["count"] for cf_list in cf_lists)
        if len(unique_domains) == remote_total:
            self.logger.warning("Lists are the same size, skipping")
            return

        # Everything below runs only when a re-sync is needed.
        # The policy is removed first, then the lists it refers to.
        cf_policies = cloudflare.get_firewall_policies(self.name_prefix)
        if len(cf_policies) > 0:
            cloudflare.delete_firewall_policy(cf_policies[0]["id"])

        for cf_list in cf_lists:
            self.logger.info(f"Deleting list {cf_list['name']}")
            cloudflare.delete_list(cf_list["id"])

        cf_lists = []

        # Upload the domains again in chunks of 1000, numbering lists from 1.
        for chunk in self.chunk_list(unique_domains, 1000):
            list_name = f"{self.name_prefix} {len(cf_lists) + 1}"
            self.logger.info(f"Creating list {list_name}")
            cf_lists.append(cloudflare.create_list(list_name, chunk))

        cf_policies = cloudflare.get_firewall_policies(self.name_prefix)
        self.logger.info(f"Number of policies in Cloudflare: {len(cf_policies)}")

        policy_name = f"{self.name_prefix} Block Ads"
        list_ids = [cf_list["id"] for cf_list in cf_lists]

        if len(cf_policies) == 0:
            self.logger.info("Creating firewall policy")
            cloudflare.create_gateway_policy(policy_name, list_ids)
        elif len(cf_policies) != 1:
            self.logger.error("More than one firewall policy found")
            raise Exception("More than one firewall policy found")
        else:
            self.logger.info("Updating firewall policy")
            cloudflare.update_gateway_policy(policy_name, cf_policies[0]["id"], list_ids)

        self.logger.info("Done")

    def is_valid_hostname(self, hostname):
        """Return True when *hostname* is made of well-formed labels.

        A hostname is rejected when it is longer than 255 characters, when
        its last label is all digits, or when any label does not start and
        end with an alphanumeric character (hyphens and underscores are
        accepted in between, up to 63 characters per label). One trailing
        dot run is ignored.
        """
        import re

        if len(hostname) > 255:
            return False

        labels = hostname.rstrip(".").split(".")

        # the TLD must not be all-numeric
        if re.match(r"^[0-9]+$", labels[-1]):
            return False

        # Raw string: "\-" and "\_" are invalid escapes in a normal literal.
        allowed = re.compile(r"^[a-z0-9]([a-z0-9\-_]{0,61}[a-z0-9])?$", re.IGNORECASE)
        return all(allowed.match(label) for label in labels)

    def download_file(self, url, name, timeout=30):
        """Download *url* and store the response body as ``tmp/<name>``.

        Args:
            url: Address of the blocklist.
            name: File name to use inside ``tmp/``.
            timeout: Seconds to wait for the server before giving up.

        Raises:
            requests.RequestException: on connection problems, timeouts or
                an HTTP error status (so an error page is never parsed as a
                blocklist).
        """
        self.logger.info(f"Downloading file from {url}")

        r = requests.get(url, allow_redirects=True, timeout=timeout)
        r.raise_for_status()

        path = pathlib.Path("tmp") / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(r.content)

        self.logger.info(f"File size: {path.stat().st_size}")

    def convert_to_domain_list(self, file_name: str):
        """Parse ``tmp/<file_name>`` into a list of domains.

        The file is treated as a hosts file when its text contains
        ``localhost``, ``127.0.0.1``, ``::1`` or ``0.0.0.0``; the domain is
        then the second whitespace-separated field of each line. Otherwise
        every line is taken as one domain.

        Blank lines, ``#``/``;`` comment lines, inline ``#`` comments, hosts
        lines without a second field, ``localhost`` entries and whitelisted
        domains are skipped.

        Returns:
            The domains in file order (duplicates are kept).
        """
        with open("tmp/" + file_name, "r", encoding="utf-8", errors="replace") as f:
            data = f.read()

        # check if the file is a hosts file or a list of domain
        hosts_markers = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
        is_hosts_file = any(marker in data for marker in hosts_markers)

        # Built once so each membership test is O(1).
        whitelist = set(self.whitelist)

        domains = []
        for raw_line in data.splitlines():
            # Drop any "#" comment (whole-line or trailing) and padding.
            line = raw_line.partition("#")[0].strip()
            if not line or line.startswith(";"):
                continue

            if is_hosts_file:
                fields = line.split()
                # A line with only an address has no domain to block.
                if len(fields) < 2:
                    continue
                domain = fields[1]
                # skip the localhost entry
                if domain == "localhost":
                    continue
            else:
                domain = line

            if domain in whitelist:
                continue

            domains.append(domain)

        self.logger.info(f"Number of domains: {len(domains)}")

        return domains

    def chunk_list(self, _list: List[str], n: int):
        """Yield consecutive slices of *_list* holding at most *n* items."""
        for start in range(0, len(_list), n):
            yield _list[start : start + n]
if __name__ == "__main__":
    # Without a configured handler the logger.info() progress messages are
    # dropped; basicConfig is a no-op if logging was already set up.
    logging.basicConfig(level=logging.INFO)
    app = App()
    app.run()