diff --git a/opensnitch/rule.py b/opensnitch/rule.py index e29080746e..c28bf0600e 100644 --- a/opensnitch/rule.py +++ b/opensnitch/rule.py @@ -18,18 +18,19 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. import logging from threading import Lock +import sqlite3 class Rule: ACCEPT = 0 DROP = 1 - def __init__(self): - self.app_path = None - self.address = None - self.port = None - self.proto = None - self.verdict = Rule.ACCEPT - + def __init__( self, app_path=None, verdict=ACCEPT, address=None, port=None, proto=None ): + self.app_path = app_path + self.verdict = verdict + self.address = address + self.port = port + self.proto = proto + def matches( self, c ): if self.app_path != c.app_path: return False @@ -46,11 +47,11 @@ def matches( self, c ): else: return True -# TODO: Implement rules persistance on file. class Rules: def __init__(self): self.mutex = Lock() - self.rules = [] + self.db = RulesDB() + self.rules = self.db.load_rules() def get_verdict( self, connection ): with self.mutex: @@ -69,12 +70,43 @@ def add_rule( self, connection, verdict, apply_to_all = False ): r = Rule() r.verdict = verdict r.app_path = connection.app_path - if apply_to_all is False: + + if apply_to_all is True: + self.db.remove_all_app_rules(r.app_path) + for rule in self.rules: + if rule.app_path == r.app_path: + self.rules.remove(rule) + elif apply_to_all is False: r.address = connection.dst_addr r.port = connection.dst_port r.proto = connection.proto - + self.rules.append(r) + self.db.save_rule(r) + +class RulesDB: + DB_PATH = "opensnitch.db" + + def __init__(self): + self.conn = sqlite3.connect(RulesDB.DB_PATH) + self._create_table() + + def _create_table(self): + c = self.conn.cursor() + c.execute("CREATE TABLE IF NOT EXISTS rules (app_path TEXT, verdict INTEGER, address TEXT, port INTEGER, proto TEXT)") + + def load_rules(self): + c = self.conn.cursor() + c.execute("SELECT * FROM rules") + return [Rule(*item) for item in c.fetchall()] + + def save_rule( self, rule ): + c = self.conn.cursor() + c.execute("INSERT INTO rules VALUES (?, ?, ?, ?, ?)", (rule.app_path, rule.verdict, rule.address, rule.port, rule.proto,)) + self.conn.commit() + def remove_all_app_rules ( self, app_path ): + c = self.conn.cursor() + c.execute("DELETE FROM rules WHERE app_path=?", (app_path,)) + self.conn.commit() -