diff --git a/nxc/database.py b/nxc/database.py index 3f93d7db5..a782fedc6 100644 --- a/nxc/database.py +++ b/nxc/database.py @@ -1,13 +1,19 @@ -import sys import configparser import shutil -from sqlalchemy import create_engine -from sqlite3 import connect +import sys from os import mkdir from os.path import exists from os.path import join as path_join +from pathlib import Path +from sqlite3 import connect +from threading import Lock + +from sqlalchemy import create_engine, MetaData +from sqlalchemy.exc import IllegalStateChangeError +from sqlalchemy.orm import sessionmaker, scoped_session from nxc.loaders.protocolloader import ProtocolLoader +from nxc.logger import nxc_logger from nxc.paths import WORKSPACE_DIR @@ -62,7 +68,7 @@ def create_workspace(workspace_name, p_loader=None): else: print(f"[*] Creating {workspace_name} workspace") mkdir(path_join(WORKSPACE_DIR, workspace_name)) - + if p_loader is None: p_loader = ProtocolLoader() protocols = p_loader.get_protocols() @@ -94,4 +100,40 @@ def delete_workspace(workspace_name): def initialize_db(): if not exists(path_join(WORKSPACE_DIR, "default")): - create_workspace("default") \ No newline at end of file + create_workspace("default") + + +class BaseDB: + def __init__(self, db_engine): + self.db_engine = db_engine + self.db_path = self.db_engine.url.database + self.protocol = Path(self.db_path).stem.upper() + self.metadata = MetaData() + self.reflect_tables() + session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) + + session = scoped_session(session_factory) + self.sess = session() + self.lock = Lock() + + def reflect_tables(self): + raise NotImplemented("Reflect tables not implemented") + + def shutdown_db(self): + try: + self.sess.close() + # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: + # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and + # this would cause an unexpected state change to + except IllegalStateChangeError as e: + nxc_logger.debug(f"Error while closing session db object: {e}") + + def clear_database(self): + for table in self.metadata.sorted_tables: + self.db_execute(table.delete()) + + def db_execute(self, *args): + self.lock.acquire() + res = self.sess.execute(*args) + self.lock.release() + return res diff --git a/nxc/protocols/ftp/database.py b/nxc/protocols/ftp/database.py index aff68bc25..a0fff2126 100644 --- a/nxc/protocols/ftp/database.py +++ b/nxc/protocols/ftp/database.py @@ -1,31 +1,23 @@ -from pathlib import Path +import sys + +from sqlalchemy import Table, select, delete, func from sqlalchemy.dialects.sqlite import Insert -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy import MetaData, Table, select, delete, func from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) + +from nxc.database import BaseDB from nxc.logger import nxc_logger -import sys -class database: +class database(BaseDB): def __init__(self, db_engine): self.CredentialsTable = None self.HostsTable = None self.LoggedinRelationsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - Session = scoped_session(session_factory) - self.sess = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -80,26 +72,13 @@ def reflect_tables(self): ) sys.exit() - def shutdown_db(self): - try: - self.sess.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.sess.execute(table.delete()) - def add_host(self, host, port, banner): """Check if this host is already in the DB, if not add it""" hosts = [] updated_ids = [] q = select(self.HostsTable).filter(self.HostsTable.c.host == host) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() # create new host if not results: @@ -133,7 +112,7 @@ def add_host(self, host, port, banner): update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns) - self.sess.execute(q, hosts) # .scalar() + self.db_execute(q, hosts) # .scalar() # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted if updated_ids: nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}") @@ -143,8 +122,9 @@ def add_credential(self, username, password): """Check if this credential has already been added to the database, if not add it in.""" credentials = [] - q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.password) == func.lower(password)) - results = self.sess.execute(q).all() + q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), + func.lower(self.CredentialsTable.c.password) == func.lower(password)) + results = self.db_execute(q).all() # add new credential if not results: @@ -170,10 +150,11 @@ def add_credential(self, username, password): # TODO: find a way to abstract this away to a single Upsert call q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id) update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"} - q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users) + q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, + set_=update_columns_users) nxc_logger.debug(f"Adding credentials: {credentials}") - self.sess.execute(q_users, credentials) # .scalar() + self.db_execute(q_users, credentials) # .scalar() # hacky way to get cred_id since we can't use returning() yet if len(credentials) == 1: @@ -187,7 +168,7 @@ def remove_credentials(self, creds_id): for cred_id in creds_id: q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id) del_hosts.append(q) - self.sess.execute(q) + self.db_execute(q) def is_credential_valid(self, credential_id): """Check if this credential ID is valid.""" @@ -195,7 +176,7 @@ def is_credential_valid(self, credential_id): self.CredentialsTable.c.id == credential_id, self.CredentialsTable.c.password is not None, ) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_credential(self, username, password): @@ -203,7 +184,7 @@ def get_credential(self, username, password): self.CredentialsTable.c.username == username, self.CredentialsTable.c.password == password, ) - results = self.sess.execute(q).first() + results = self.db_execute(q).first() if results is not None: return results.id @@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None): else: q = select(self.CredentialsTable) - return self.sess.execute(q).all() + return self.db_execute(q).all() def is_host_valid(self, host_id): """Check if this host ID is valid.""" q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None): @@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None): # if we're returning a single host by ID if self.is_host_valid(filter_term): q = q.filter(self.HostsTable.c.id == filter_term) - results = self.sess.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by host elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(self.HostsTable.c.host.like(like_term)) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"FTP get_hosts() - results: {results}") return results def is_user_valid(self, cred_id): """Check if this User ID is valid.""" q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_user(self, username): q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username)) - return self.sess.execute(q).all() + return self.db_execute(q).all() def get_users(self, filter_term=None): q = select(self.CredentialsTable) @@ -265,14 +246,14 @@ def get_users(self, filter_term=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term)) - return self.sess.execute(q).all() + return self.db_execute(q).all() def add_loggedin_relation(self, cred_id, host_id): relation_query = select(self.LoggedinRelationsTable).filter( self.LoggedinRelationsTable.c.credid == cred_id, self.LoggedinRelationsTable.c.hostid == host_id, ) - results = self.sess.execute(relation_query).all() + results = self.db_execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id): # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - self.sess.execute(q, [relation]) # .scalar() + self.db_execute(q, [relation]) # .scalar() inserted_id_results = self.get_loggedin_relations(cred_id, host_id) nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}") return inserted_id_results[0].id @@ -295,7 +276,7 @@ def get_loggedin_relations(self, cred_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id) if host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - return self.sess.execute(q).all() + return self.db_execute(q).all() def remove_loggedin_relations(self, cred_id=None, host_id=None): q = delete(self.LoggedinRelationsTable) @@ -303,7 +284,7 @@ def remove_loggedin_relations(self, cred_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id) elif host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - self.sess.execute(q) + self.db_execute(q) def add_directory_listing(self, lir_id, data): pass diff --git a/nxc/protocols/ldap/database.py b/nxc/protocols/ldap/database.py index 9ca4b740c..2f08e9566 100644 --- a/nxc/protocols/ldap/database.py +++ b/nxc/protocols/ldap/database.py @@ -1,30 +1,20 @@ -from pathlib import Path -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy import MetaData, Table +import sys + +from sqlalchemy import Table from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) -from nxc.logger import nxc_logger -import sys + +from nxc.database import BaseDB -class database: +class database(BaseDB): def __init__(self, db_engine): self.CredentialsTable = None self.HostsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -59,16 +49,3 @@ def reflect_tables(self): [-] Then remove the nxc {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB""" ) sys.exit() - - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) diff --git a/nxc/protocols/mssql/database.py b/nxc/protocols/mssql/database.py index 4782e8bb4..6ff90802b 100755 --- a/nxc/protocols/mssql/database.py +++ b/nxc/protocols/mssql/database.py @@ -1,37 +1,24 @@ -from pathlib import Path -from sqlalchemy import MetaData, func, Table, select, insert, update, delete -from sqlalchemy.dialects.sqlite import Insert # used for upsert -from sqlalchemy.exc import ( - IllegalStateChangeError, - NoInspectionAvailable, - NoSuchTableError, -) -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy.exc import SAWarning +import sys import warnings + +from sqlalchemy import func, select, insert, update, delete, Table +from sqlalchemy.dialects.sqlite import Insert # used for upsert +from sqlalchemy.exc import SAWarning, NoInspectionAvailable, NoSuchTableError + +from nxc.database import BaseDB from nxc.logger import nxc_logger -import sys # if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings warnings.filterwarnings("ignore", category=SAWarning) -class database: +class database(BaseDB): def __init__(self, db_engine): self.HostsTable = None self.UsersTable = None self.AdminRelationsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -83,19 +70,6 @@ def reflect_tables(self): ) sys.exit() - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) - def add_host(self, ip, hostname, domain, os, instances): """ Check if this host has already been added to the database, if not, add it in. @@ -107,7 +81,7 @@ def add_host(self, ip, hostname, domain, os, instances): hosts = [] q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"mssql add_host() - hosts returned: {results}") host_data = { @@ -142,7 +116,7 @@ def add_host(self, ip, hostname, domain, os, instances): q = Insert(self.HostsTable) update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns) - self.conn.execute(q, hosts) + self.db_execute(q, hosts) def add_credential(self, credtype, domain, username, password, pillaged_from=None): """Check if this credential has already been added to the database, if not add it in.""" @@ -165,7 +139,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype), ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() if not results: user_data = { @@ -176,15 +150,16 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non "pillaged_from_hostid": pillaged_from, } q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id) - self.conn.execute(q) # .first() + self.db_execute(q) # .first() else: for user in results: # might be able to just remove this if check, but leaving it in for now if not user[3] and not user[4] and not user[5]: q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id) - results = self.conn.execute(q) # .first() + results = self.db_execute(q) # .first() - nxc_logger.debug(f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})") + nxc_logger.debug( + f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})") return user_rowid def remove_credentials(self, creds_id): @@ -193,12 +168,12 @@ def remove_credentials(self, creds_id): for cred_id in creds_id: q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id) del_hosts.append(q) - self.conn.execute(q) + self.db_execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): if user_id: q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id) - users = self.conn.execute(q).all() + users = self.db_execute(q).all() else: q = select(self.UsersTable).filter( self.UsersTable.c.credtype == credtype, @@ -206,12 +181,12 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password, ) - users = self.conn.execute(q).all() + users = self.db_execute(q).all() nxc_logger.debug(f"Users: {users}") like_term = func.lower(f"%{host}%") q = q.filter(self.HostsTable.c.ip.like(like_term)) - hosts = self.conn.execute(q).all() + hosts = self.db_execute(q).all() nxc_logger.debug(f"Hosts: {hosts}") if users is not None and hosts is not None: @@ -224,10 +199,10 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id, ) - links = self.conn.execute(q).all() + links = self.db_execute(q).all() if not links: - self.conn.execute(insert(self.AdminRelationsTable).values(link)) + self.db_execute(insert(self.AdminRelationsTable).values(link)) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -237,7 +212,7 @@ def get_admin_relations(self, user_id=None, host_id=None): else: q = select(self.AdminRelationsTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_admin_relation(self, user_ids=None, host_ids=None): q = delete(self.AdminRelationsTable) @@ -247,7 +222,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None): elif host_ids: for host_id in host_ids: q = q.filter(self.AdminRelationsTable.c.hostid == host_id) - self.conn.execute(q) + self.db_execute(q) def is_credential_valid(self, credential_id): """Check if this credential ID is valid.""" @@ -255,7 +230,7 @@ def is_credential_valid(self, credential_id): self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None, ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -273,12 +248,12 @@ def get_credentials(self, filter_term=None, cred_type=None): else: q = select(self.UsersTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def is_host_valid(self, host_id): """Check if this host ID is valid.""" q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None, domain=None): @@ -288,7 +263,7 @@ def get_hosts(self, filter_term=None, domain=None): # if we're returning a single host by ID if self.is_host_valid(filter_term): q = q.filter(self.HostsTable.c.id == filter_term) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -299,6 +274,7 @@ def get_hosts(self, filter_term=None, domain=None): # if we're filtering by ip/hostname elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") - q = select(self.HostsTable).filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term)) + q = select(self.HostsTable).filter( + self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term)) - return self.conn.execute(q).all() + return self.db_execute(q).all() diff --git a/nxc/protocols/rdp/database.py b/nxc/protocols/rdp/database.py index 7a34c5a5b..2053b16d1 100644 --- a/nxc/protocols/rdp/database.py +++ b/nxc/protocols/rdp/database.py @@ -1,31 +1,20 @@ -from pathlib import Path +import sys -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy import MetaData, Table +from sqlalchemy import Table from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) -from nxc.logger import nxc_logger -import sys + +from nxc.database import BaseDB -class database: +class database(BaseDB): def __init__(self, db_engine): self.CredentialsTable = None self.HostsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -62,16 +51,3 @@ def reflect_tables(self): [-] Then remove the {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB""" ) sys.exit() - - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) diff --git a/nxc/protocols/smb/database.py b/nxc/protocols/smb/database.py index f32d7beeb..37d02a551 100755 --- a/nxc/protocols/smb/database.py +++ b/nxc/protocols/smb/database.py @@ -1,27 +1,25 @@ import base64 +import sys import warnings from datetime import datetime -from pathlib import Path +from typing import Optional -from sqlalchemy import MetaData, func, Table, select, delete +from sqlalchemy import func, Table, select, delete from sqlalchemy.dialects.sqlite import Insert # used for upsert from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) from sqlalchemy.exc import SAWarning -from sqlalchemy.orm import sessionmaker, scoped_session +from nxc.database import BaseDB from nxc.logger import nxc_logger -import sys -from typing import Optional # if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings warnings.filterwarnings("ignore", category=SAWarning) -class database: +class database(BaseDB): def __init__(self, db_engine): self.HostsTable = None self.UsersTable = None @@ -35,16 +33,7 @@ def __init__(self, db_engine): self.DpapiBackupkey = None self.DpapiSecrets = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -199,39 +188,26 @@ def reflect_tables(self): ) sys.exit() - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) - # pull/545 def add_host( - self, - ip, - hostname, - domain, - os, - smbv1, - signing, - spooler=None, - zerologon=None, - petitpotam=None, - dc=None, + self, + ip, + hostname, + domain, + os, + smbv1, + signing, + spooler=None, + zerologon=None, + petitpotam=None, + dc=None, ): """Check if this host has already been added to the database, if not, add it in.""" hosts = [] updated_ids = [] q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() # create new host if not results: @@ -284,7 +260,7 @@ def add_host( update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns) - self.conn.execute(q, hosts) # .scalar() + self.db_execute(q, hosts) # .scalar() # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted if updated_ids: nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}") @@ -295,7 +271,8 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi credentials = [] groups = [] - if (group_id and not self.is_group_valid(group_id)) or (pillaged_from and not self.is_host_valid(pillaged_from)): + if (group_id and not self.is_group_valid(group_id)) or ( + pillaged_from and not self.is_host_valid(pillaged_from)): nxc_logger.debug("Invalid group or host") return @@ -304,7 +281,7 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype), ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() # add new credential if not results: @@ -346,12 +323,12 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi q_users = q_users.on_conflict_do_update(index_elements=self.UsersTable.primary_key, set_=update_columns_users) nxc_logger.debug(f"Adding credentials: {credentials}") - self.conn.execute(q_users, credentials) # .scalar() + self.db_execute(q_users, credentials) # .scalar() if groups: q_groups = Insert(self.GroupRelationsTable) - self.conn.execute(q_groups, groups) + self.db_execute(q_groups, groups) def remove_credentials(self, creds_id): """Removes a credential ID from the database""" @@ -359,14 +336,17 @@ def remove_credentials(self, creds_id): for cred_id in creds_id: q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id) del_hosts.append(q) - self.conn.execute(q) + self.db_execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): add_links = [] creds_q = select(self.UsersTable) - creds_q = creds_q.filter(self.UsersTable.c.id == user_id) if user_id else creds_q.filter(func.lower(self.UsersTable.c.credtype) == func.lower(credtype), func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password) - users = self.conn.execute(creds_q) + creds_q = creds_q.filter(self.UsersTable.c.id == user_id) if user_id else creds_q.filter( + func.lower(self.UsersTable.c.credtype) == func.lower(credtype), + func.lower(self.UsersTable.c.domain) == func.lower(domain), + func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password) + users = self.db_execute(creds_q) hosts = self.get_hosts(host) if users and hosts: @@ -378,7 +358,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id, ) - links = self.conn.execute(admin_relations_select).all() + links = self.db_execute(admin_relations_select).all() if not links: add_links.append(link) @@ -386,7 +366,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non admin_relations_insert = Insert(self.AdminRelationsTable) if add_links: - self.conn.execute(admin_relations_insert, add_links) + self.db_execute(admin_relations_insert, add_links) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -396,7 +376,7 @@ def get_admin_relations(self, user_id=None, host_id=None): else: q = select(self.AdminRelationsTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_admin_relation(self, user_ids=None, host_ids=None): q = delete(self.AdminRelationsTable) @@ -406,7 +386,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None): elif host_ids: for host_id in host_ids: q = q.filter(self.AdminRelationsTable.c.hostid == host_id) - self.conn.execute(q) + self.db_execute(q) def is_credential_valid(self, credential_id): """Check if this credential ID is valid.""" @@ -414,7 +394,7 @@ def is_credential_valid(self, credential_id): self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None, ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -432,7 +412,7 @@ def get_credentials(self, filter_term=None, cred_type=None): else: q = select(self.UsersTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_credential(self, cred_type, domain, username, password): q = select(self.UsersTable).filter( @@ -441,22 +421,22 @@ def get_credential(self, cred_type, domain, username, password): self.UsersTable.c.password == password, self.UsersTable.c.credtype == cred_type, ) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() return results.id def is_credential_local(self, credential_id): q = select(self.UsersTable.c.domain).filter(self.UsersTable.c.id == credential_id) - user_domain = self.conn.execute(q).all() + user_domain = self.db_execute(q).all() if user_domain: q = select(self.HostsTable).filter(func.lower(self.HostsTable.c.id) == func.lower(user_domain)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def is_host_valid(self, host_id): """Check if this host ID is valid.""" q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None, domain=None): @@ -466,7 +446,7 @@ def get_hosts(self, filter_term=None, domain=None): # if we're returning a single host by ID if self.is_host_valid(filter_term): q = q.filter(self.HostsTable.c.id == filter_term) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -491,14 +471,14 @@ def get_hosts(self, filter_term=None, domain=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"smb hosts() - results: {results}") return results def is_group_valid(self, group_id): """Check if this group ID is valid.""" q = select(self.GroupsTable).filter(self.GroupsTable.c.id == group_id) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() valid = bool(results) nxc_logger.debug(f"is_group_valid(groupID={group_id}) => {valid}") @@ -530,7 +510,7 @@ def add_group(self, domain, name, rid=None, member_count_ad=None): # insert the group and get the returned id right away, this can be refactored when we can use RETURNING q = Insert(self.GroupsTable) - self.conn.execute(q, groups) + self.db_execute(q, groups) new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"]) returned_id = [new_group_data[0].id] nxc_logger.debug(f"Inserted group with ID: {returned_id[0]}") @@ -561,7 +541,7 @@ def add_group(self, domain, name, rid=None, member_count_ad=None): update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.GroupsTable.primary_key, set_=update_columns) - self.conn.execute(q, groups) + self.db_execute(q, groups) # TODO: always return a list and fix code references to not expect a single integer # if updated_ids: @@ -572,7 +552,7 @@ def get_groups(self, filter_term=None, group_name=None, group_domain=None): """Return groups from the database""" if filter_term and self.is_group_valid(filter_term): q = select(self.GroupsTable).filter(self.GroupsTable.c.id == filter_term) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif group_name and group_domain: @@ -586,9 +566,10 @@ def get_groups(self, filter_term=None, group_name=None, group_domain=None): else: q = select(self.GroupsTable).filter() - results = self.conn.execute(q).all() + results = self.db_execute(q).all() - nxc_logger.debug(f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}") + nxc_logger.debug( + f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}") return results def get_group_relations(self, user_id=None, group_id=None): @@ -602,7 +583,7 @@ def get_group_relations(self, user_id=None, group_id=None): elif group_id: q = select(self.GroupRelationsTable).filter(self.GroupRelationsTable.c.groupid == group_id) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_group_relations(self, user_id=None, group_id=None): q = delete(self.GroupRelationsTable) @@ -610,12 +591,12 @@ def remove_group_relations(self, user_id=None, group_id=None): q = q.filter(self.GroupRelationsTable.c.userid == user_id) elif group_id: q = q.filter(self.GroupRelationsTable.c.groupid == group_id) - self.conn.execute(q) + self.db_execute(q) def is_user_valid(self, user_id): """Check if this User ID is valid.""" q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_users(self, filter_term=None): @@ -627,14 +608,14 @@ def get_users(self, filter_term=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(func.lower(self.UsersTable.c.username).like(like_term)) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_user(self, domain, username): q = select(self.UsersTable).filter( func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username), ) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_domain_controllers(self, domain=None): return self.get_hosts(filter_term="dc", domain=domain) @@ -642,7 +623,7 @@ def get_domain_controllers(self, domain=None): def is_share_valid(self, share_id): """Check if this share ID is valid.""" q = select(self.SharesTable).filter(self.SharesTable.c.id == share_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}") return len(results) > 0 @@ -656,7 +637,7 @@ def add_share(self, host_id, user_id, name, remark, read, write): "read": read, "write": write, } - self.conn.execute( + self.db_execute( Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id), share_data, ) # .scalar_one() @@ -669,7 +650,7 @@ def get_shares(self, filter_term=None): q = select(self.SharesTable).filter(self.SharesTable.c.name.like(like_term)) else: q = select(self.SharesTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_shares_by_access(self, permissions, share_id=None): permissions = permissions.lower() @@ -680,17 +661,17 @@ def get_shares_by_access(self, permissions, share_id=None): q = q.filter(self.SharesTable.c.read == 1) if "w" in permissions: q = q.filter(self.SharesTable.c.write == 1) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_users_with_share_access(self, host_id, share_name, permissions): permissions = permissions.lower() - q = select(self.SharesTable.c.userid).filter(self.SharesTable.c.name == share_name, self.SharesTable.c.hostid == host_id) + q = select(self.SharesTable.c.userid).filter(self.SharesTable.c.name == share_name, + self.SharesTable.c.hostid == host_id) if "r" in permissions: q = q.filter(self.SharesTable.c.read == 1) if "w" in permissions: q = q.filter(self.SharesTable.c.write == 1) - return self.conn.execute(q).all() - + return self.db_execute(q).all() def add_domain_backupkey(self, domain: str, pvk: bytes): """ @@ -699,7 +680,7 @@ def add_domain_backupkey(self, domain: str, pvk: bytes): :pvk is the domain backupkey """ q = select(self.DpapiBackupkey).filter(func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() if not len(results): pvk_encoded = base64.b64encode(pvk) @@ -708,7 +689,7 @@ def add_domain_backupkey(self, domain: str, pvk: bytes): # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id) - self.conn.execute(q, [backup_key]) # .scalar() + self.db_execute(q, [backup_key]) # .scalar() nxc_logger.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})") except Exception as e: nxc_logger.debug(f"Issue while inserting DPAPI Backup Key: {e}") @@ -721,7 +702,7 @@ def get_domain_backupkey(self, domain: Optional[str] = None): q = select(self.DpapiBackupkey) if domain is not None: q = q.filter(func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"get_domain_backupkey(domain={domain}) => {results}") @@ -735,19 +716,19 @@ def is_dpapi_secret_valid(self, dpapi_secret_id): :dpapi_secret_id is a primary id """ q = select(self.DpapiSecrets).filter(func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() valid = results is not None nxc_logger.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}") return valid def add_dpapi_secrets( - self, - host: str, - dpapi_type: str, - windows_user: str, - username: str, - password: str, - url: str = "", + self, + host: str, + dpapi_type: str, + windows_user: str, + username: str, + password: str, + url: str = "", ): """Add dpapi secrets to nxcdb""" secret = { @@ -760,31 +741,31 @@ def add_dpapi_secrets( } q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id) - self.conn.execute(q, [secret]) # .scalar() - + self.db_execute(q, [secret]) # .scalar() - nxc_logger.debug(f"add_dpapi_secrets(host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, password={password}, url={url})") + nxc_logger.debug( + f"add_dpapi_secrets(host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, password={password}, url={url})") def get_dpapi_secrets( - self, - filter_term=None, - host: Optional[str] = None, - dpapi_type: Optional[str] = None, - windows_user: Optional[str] = None, - username: Optional[str] = None, - url: Optional[str] = None, + self, + filter_term=None, + host: Optional[str] = None, + dpapi_type: Optional[str] = None, + windows_user: Optional[str] = None, + username: Optional[str] = None, + url: Optional[str] = None, ): """Get dpapi secrets from nxcdb""" q = select(self.DpapiSecrets) if self.is_dpapi_secret_valid(filter_term): q = q.filter(self.DpapiSecrets.c.id == filter_term) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif host: q = q.filter(self.DpapiSecrets.c.host == host) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] elif dpapi_type: @@ -797,9 +778,10 @@ def get_dpapi_secrets( q = q.filter(func.lower(self.DpapiSecrets.c.windows_user).like(like_term)) elif url: q = q.filter(func.lower(self.DpapiSecrets.c.url) == func.lower(url)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() - nxc_logger.debug(f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}") + nxc_logger.debug( + f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}") return results def add_loggedin_relation(self, user_id, host_id): @@ -807,7 +789,7 @@ def add_loggedin_relation(self, user_id, host_id): self.LoggedinRelationsTable.c.userid == user_id, self.LoggedinRelationsTable.c.hostid == host_id, ) - results = self.conn.execute(relation_query).all() + results = self.db_execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -817,7 +799,7 @@ def add_loggedin_relation(self, user_id, host_id): # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - self.conn.execute(q, [relation]) # .scalar() + self.db_execute(q, [relation]) # .scalar() inserted_id_results = self.get_loggedin_relations(user_id, host_id) nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}") return inserted_id_results[0].id @@ -830,7 +812,7 @@ def get_loggedin_relations(self, user_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.userid == user_id) if host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_loggedin_relations(self, user_id=None, host_id=None): q = delete(self.LoggedinRelationsTable) @@ -838,15 +820,15 @@ def remove_loggedin_relations(self, user_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.userid == user_id) elif host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - self.conn.execute(q) + self.db_execute(q) def get_checks(self): q = select(self.ConfChecksTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_check_results(self): q = select(self.ConfChecksResultsTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def insert_data(self, table, select_results=None, **new_row): """ @@ -878,14 +860,14 @@ def insert_data(self, table, select_results=None, **new_row): q = Insert(table) # .returning(table.c.id) update_column = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=table.primary_key, set_=update_column) - self.conn.execute(q, results) # .scalar() + self.db_execute(q, results) # .scalar() # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted return updated_ids def add_check(self, name, description): """Check if this check item has already been added to the database, if not, add it in.""" q = select(self.ConfChecksTable).filter(self.ConfChecksTable.c.name == name) - select_results = self.conn.execute(q).all() + select_results = self.db_execute(q).all() context = locals() new_row = {column: context[column] for column in ("name", "description")} updated_ids = self.insert_data(self.ConfChecksTable, select_results, **new_row) @@ -896,8 +878,9 @@ def add_check(self, name, description): def add_check_result(self, host_id, check_id, secure, reasons): """Check if this check result has already been added to the database, if not, add it in.""" - q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id, self.ConfChecksResultsTable.c.check_id == check_id) - select_results = self.conn.execute(q).all() + q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id, + self.ConfChecksResultsTable.c.check_id == check_id) + select_results = self.db_execute(q).all() context = locals() new_row = {column: context[column] for column in ("host_id", "check_id", "secure", "reasons")} updated_ids = self.insert_data(self.ConfChecksResultsTable, select_results, **new_row) diff --git a/nxc/protocols/ssh/database.py b/nxc/protocols/ssh/database.py index f38ab45ae..1de410c4e 100644 --- a/nxc/protocols/ssh/database.py +++ b/nxc/protocols/ssh/database.py @@ -1,19 +1,17 @@ +import configparser +import os +import sys + +from sqlalchemy import Table, select, func, delete from sqlalchemy.dialects.sqlite import Insert -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy import MetaData, Table, select, func, delete from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) -import os -from pathlib import Path -import configparser - +from nxc.database import BaseDB from nxc.logger import nxc_logger from nxc.paths import NXC_PATH -import sys # we can't import config.py due to a circular dependency, so we have to create redundant code unfortunately nxc_config = configparser.ConfigParser() @@ -21,23 +19,14 @@ nxc_workspace = nxc_config.get("nxc", "workspace", fallback="default") -class database: +class database(BaseDB): def __init__(self, db_engine): self.CredentialsTable = None self.HostsTable = None self.LoggedinRelationsTable = None self.AdminRelationsTable = None self.KeysTable = None - - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - self.sess = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -105,26 +94,13 @@ def reflect_tables(self): ) sys.exit() - def shutdown_db(self): - try: - self.sess.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.sess.execute(table.delete()) - def add_host(self, host, port, banner, os=None): """Check if this host has already been added to the database, if not, add it in.""" hosts = [] updated_ids = [] q = select(self.HostsTable).filter(self.HostsTable.c.host == host) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"add_host(): Initial hosts results: {results}") # create new host @@ -162,7 +138,7 @@ def add_host(self, host, port, banner, os=None): update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns) - self.sess.execute(q, hosts) # .scalar() + self.db_execute(q, hosts) # .scalar() # we only return updated IDs for now - when RETURNING clause is allowed we can return inserted if updated_ids: nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}") @@ -183,13 +159,13 @@ def add_credential(self, credtype, username, password, key=None): self.KeysTable.c.data == key, ) ) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() else: q = select(self.CredentialsTable).filter( func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.credtype) == func.lower(credtype), ) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() # add new credential if not results: @@ -218,10 +194,11 @@ def add_credential(self, credtype, username, password, key=None): # TODO: find a way to abstract this away to a single Upsert call q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id) update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"} - q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users) + q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, + set_=update_columns_users) nxc_logger.debug(f"Adding credentials: {credentials}") - self.sess.execute(q_users, credentials) # .scalar() + self.db_execute(q_users, credentials) # .scalar() # hacky way to get cred_id since we can't use returning() yet if len(credentials) == 1: @@ -238,19 +215,19 @@ def remove_credentials(self, creds_id): for cred_id in creds_id: q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id) del_hosts.append(q) - self.sess.execute(q) + self.db_execute(q) def add_key(self, cred_id, key): # check if key relation already exists - check_q = self.sess.execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all() + check_q = self.db_execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all() nxc_logger.debug(f"check_q: {check_q}") if check_q: nxc_logger.debug(f"Key already exists for cred_id {cred_id}") return None key_data = {"credid": cred_id, "data": key} - self.sess.execute(Insert(self.KeysTable), key_data) - key_id = self.sess.execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()[0].id + self.db_execute(Insert(self.KeysTable), key_data) + key_id = self.db_execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()[0].id nxc_logger.debug(f"Key added: {key_id}") return key_id @@ -260,7 +237,7 @@ def get_keys(self, key_id=None, cred_id=None): q = q.filter(self.KeysTable.c.id == key_id) elif cred_id is not None: q = q.filter(self.KeysTable.c.credid == cred_id) - return self.sess.execute(q).all() + return self.db_execute(q).all() def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None): add_links = [] @@ -274,7 +251,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None) func.lower(self.CredentialsTable.c.username) == func.lower(username), self.CredentialsTable.c.password == secret, ) - creds = self.sess.execute(creds_q) + creds = self.db_execute(creds_q) hosts = self.get_hosts(host_id) if creds and hosts: @@ -286,7 +263,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None) self.AdminRelationsTable.c.credid == cred_id, self.AdminRelationsTable.c.hostid == host_id, ) - links = self.sess.execute(admin_relations_select).all() + links = self.db_execute(admin_relations_select).all() if not links: add_links.append(link) @@ -294,7 +271,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None) admin_relations_insert = Insert(self.AdminRelationsTable) if add_links: - self.sess.execute(admin_relations_insert, add_links) + self.db_execute(admin_relations_insert, add_links) def get_admin_relations(self, cred_id=None, host_id=None): if cred_id: @@ -304,7 +281,7 @@ def get_admin_relations(self, cred_id=None, host_id=None): else: q = select(self.AdminRelationsTable) - return self.sess.execute(q).all() + return self.db_execute(q).all() def remove_admin_relation(self, cred_ids=None, host_ids=None): q = delete(self.AdminRelationsTable) @@ -314,7 +291,7 @@ def remove_admin_relation(self, cred_ids=None, host_ids=None): elif host_ids: for host_id in host_ids: q = q.filter(self.AdminRelationsTable.c.hostid == host_id) - self.sess.execute(q) + self.db_execute(q) def is_credential_valid(self, credential_id): """Check if this credential ID is valid.""" @@ -322,7 +299,7 @@ def is_credential_valid(self, credential_id): self.CredentialsTable.c.id == credential_id, self.CredentialsTable.c.password is not None, ) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -340,7 +317,7 @@ def get_credentials(self, filter_term=None, cred_type=None): else: q = select(self.CredentialsTable) - return self.sess.execute(q).all() + return self.db_execute(q).all() def get_credential(self, cred_type, username, password): q = select(self.CredentialsTable).filter( @@ -348,14 +325,14 @@ def get_credential(self, cred_type, username, password): self.CredentialsTable.c.password == password, self.CredentialsTable.c.credtype == cred_type, ) - results = self.sess.execute(q).first() + results = self.db_execute(q).first() if results is not None: return results.id def is_host_valid(self, host_id): """Check if this host ID is valid.""" q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None): @@ -365,21 +342,21 @@ def get_hosts(self, filter_term=None): # if we're returning a single host by ID if self.is_host_valid(filter_term): q = q.filter(self.HostsTable.c.id == filter_term) - results = self.sess.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by host elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(self.HostsTable.c.host.like(like_term)) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"SSH get_hosts() - results: {results}") return results def is_user_valid(self, cred_id): """Check if this User ID is valid.""" q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id) - results = self.sess.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_users(self, filter_term=None): @@ -391,18 +368,18 @@ def get_users(self, filter_term=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term)) - return self.sess.execute(q).all() + return self.db_execute(q).all() def get_user(self, domain, username): q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username)) - return self.sess.execute(q).all() + return self.db_execute(q).all() def add_loggedin_relation(self, cred_id, host_id, shell=False): relation_query = select(self.LoggedinRelationsTable).filter( self.LoggedinRelationsTable.c.credid == cred_id, self.LoggedinRelationsTable.c.hostid == host_id, ) - results = self.sess.execute(relation_query).all() + results = self.db_execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -412,7 +389,7 @@ def add_loggedin_relation(self, cred_id, host_id, shell=False): # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - self.sess.execute(q, [relation]) # .scalar() + self.db_execute(q, [relation]) # .scalar() inserted_id_results = self.get_loggedin_relations(cred_id, host_id) nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}") return inserted_id_results[0].id @@ -427,7 +404,7 @@ def get_loggedin_relations(self, cred_id=None, host_id=None, shell=None): q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) if shell: q = q.filter(self.LoggedinRelationsTable.c.shell == shell) - return self.sess.execute(q).all() + return self.db_execute(q).all() def remove_loggedin_relations(self, cred_id=None, host_id=None): q = delete(self.LoggedinRelationsTable) @@ -435,4 +412,4 @@ def remove_loggedin_relations(self, cred_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id) elif host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - self.sess.execute(q) + self.db_execute(q) diff --git a/nxc/protocols/vnc/database.py b/nxc/protocols/vnc/database.py index 0be660c78..4f6e056e2 100644 --- a/nxc/protocols/vnc/database.py +++ b/nxc/protocols/vnc/database.py @@ -1,36 +1,25 @@ -from pathlib import Path -from sqlalchemy import MetaData, Table +import sys +import warnings + +from sqlalchemy import Table from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) -from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.exc import SAWarning -import warnings -from nxc.logger import nxc_logger -import sys +from nxc.database import BaseDB # if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings warnings.filterwarnings("ignore", category=SAWarning) -class database: +class database(BaseDB): def __init__(self, db_engine): self.HostsTable = None self.CredentialsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -67,16 +56,3 @@ def reflect_tables(self): [-] Then remove the {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB""" ) sys.exit() - - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) diff --git a/nxc/protocols/winrm/database.py b/nxc/protocols/winrm/database.py index d361ae85b..b8d01bfce 100644 --- a/nxc/protocols/winrm/database.py +++ b/nxc/protocols/winrm/database.py @@ -1,33 +1,25 @@ -from pathlib import Path from sqlalchemy.dialects.sqlite import Insert -from sqlalchemy.orm import sessionmaker, scoped_session -from sqlalchemy import MetaData, Table, select, func, delete +import sys + +from sqlalchemy import Table, select, func, delete +from sqlalchemy.dialects.sqlite import Insert from sqlalchemy.exc import ( - IllegalStateChangeError, NoInspectionAvailable, NoSuchTableError, ) + +from nxc.database import BaseDB from nxc.logger import nxc_logger -import sys -class database: +class database(BaseDB): def __init__(self, db_engine): self.HostsTable = None self.UsersTable = None self.AdminRelationsTable = None self.LoggedinRelationsTable = None - self.db_engine = db_engine - self.db_path = self.db_engine.url.database - self.protocol = Path(self.db_path).stem.upper() - self.metadata = MetaData() - self.reflect_tables() - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - - Session = scoped_session(session_factory) - # this is still named "conn" when it is the session object; TODO: rename - self.conn = Session() + super().__init__(db_engine) @staticmethod def db_schema(db_conn): @@ -88,19 +80,6 @@ def reflect_tables(self): ) sys.exit() - def shutdown_db(self): - try: - self.conn.close() - # due to the async nature of nxc, sometimes session state is a bit messy and this will throw: - # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and - # this would cause an unexpected state change to - except IllegalStateChangeError as e: - nxc_logger.debug(f"Error while closing session db object: {e}") - - def clear_database(self): - for table in self.metadata.sorted_tables: - self.conn.execute(table.delete()) - def add_host(self, ip, port, hostname, domain, os=None): """ Check if this host has already been added to the database, if not, add it in. @@ -110,7 +89,7 @@ def add_host(self, ip, port, hostname, domain, os=None): hosts = [] q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"smb add_host() - hosts returned: {results}") # create new host @@ -147,7 +126,7 @@ def add_host(self, ip, port, hostname, domain, os=None): q = Insert(self.HostsTable) update_columns = {col.name: col for col in q.excluded if col.name not in "id"} q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns) - self.conn.execute(q, hosts) + self.db_execute(q, hosts) def add_credential(self, credtype, domain, username, password, pillaged_from=None): """Check if this credential has already been added to the database, if not add it in.""" @@ -171,7 +150,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non func.lower(self.UsersTable.c.username) == func.lower(username), func.lower(self.UsersTable.c.credtype) == func.lower(credtype), ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() # add new credential if not results: @@ -207,7 +186,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non q_users = Insert(self.UsersTable) # .returning(self.UsersTable.c.id) update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"} q_users = q_users.on_conflict_do_update(index_elements=self.UsersTable.primary_key, set_=update_columns_users) - self.conn.execute(q_users, credentials) # .scalar() + self.db_execute(q_users, credentials) # .scalar() def remove_credentials(self, creds_id): """Removes a credential ID from the database""" @@ -215,7 +194,7 @@ def remove_credentials(self, creds_id): for cred_id in creds_id: q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id) del_hosts.append(q) - self.conn.execute(q) + self.db_execute(q) def add_admin_user(self, credtype, domain, username, password, host, user_id=None): domain = domain.split(".")[0] @@ -231,7 +210,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password, ) - users = self.conn.execute(creds_q) + users = self.db_execute(creds_q) hosts = self.get_hosts(host) if users and hosts: @@ -243,14 +222,14 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non self.AdminRelationsTable.c.userid == user_id, self.AdminRelationsTable.c.hostid == host_id, ) - links = self.conn.execute(admin_relations_select).all() + links = self.db_execute(admin_relations_select).all() if not links: add_links.append(link) admin_relations_insert = Insert(self.AdminRelationsTable) - self.conn.execute(admin_relations_insert, add_links) + self.db_execute(admin_relations_insert, add_links) def get_admin_relations(self, user_id=None, host_id=None): if user_id: @@ -260,7 +239,7 @@ def get_admin_relations(self, user_id=None, host_id=None): else: q = select(self.AdminRelationsTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_admin_relation(self, user_ids=None, host_ids=None): q = delete(self.AdminRelationsTable) @@ -270,7 +249,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None): elif host_ids: for host_id in host_ids: q = q.filter(self.AdminRelationsTable.c.hostid == host_id) - self.conn.execute(q) + self.db_execute(q) def is_credential_valid(self, credential_id): """Check if this credential ID is valid.""" @@ -278,7 +257,7 @@ def is_credential_valid(self, credential_id): self.UsersTable.c.id == credential_id, self.UsersTable.c.password is not None, ) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_credentials(self, filter_term=None, cred_type=None): @@ -296,22 +275,22 @@ def get_credentials(self, filter_term=None, cred_type=None): else: q = select(self.UsersTable) - return self.conn.execute(q).all() + return self.db_execute(q).all() def is_credential_local(self, credential_id): q = select(self.UsersTable.c.domain).filter(self.UsersTable.c.id == credential_id) - user_domain = self.conn.execute(q).all() + user_domain = self.db_execute(q).all() if user_domain: q = select(self.HostsTable).filter(func.lower(self.HostsTable.c.id) == func.lower(user_domain)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def is_host_valid(self, host_id): """Check if this host ID is valid.""" q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_hosts(self, filter_term=None): @@ -321,7 +300,7 @@ def get_hosts(self, filter_term=None): # if we're returning a single host by ID if self.is_host_valid(filter_term): q = q.filter(self.HostsTable.c.id == filter_term) - results = self.conn.execute(q).first() + results = self.db_execute(q).first() # all() returns a list, so we keep the return format the same so consumers don't have to guess return [results] # if we're filtering by domain controllers @@ -333,14 +312,14 @@ def get_hosts(self, filter_term=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term)) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() nxc_logger.debug(f"winrm get_hosts() - results: {results}") return results def is_user_valid(self, user_id): """Check if this User ID is valid.""" q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id) - results = self.conn.execute(q).all() + results = self.db_execute(q).all() return len(results) > 0 def get_users(self, filter_term=None): @@ -352,21 +331,21 @@ def get_users(self, filter_term=None): elif filter_term and filter_term != "": like_term = func.lower(f"%{filter_term}%") q = q.filter(func.lower(self.UsersTable.c.username).like(like_term)) - return self.conn.execute(q).all() + return self.db_execute(q).all() def get_user(self, domain, username): q = select(self.UsersTable).filter( func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username), ) - return self.conn.execute(q).all() + return self.db_execute(q).all() def add_loggedin_relation(self, user_id, host_id): relation_query = select(self.LoggedinRelationsTable).filter( self.LoggedinRelationsTable.c.userid == user_id, self.LoggedinRelationsTable.c.hostid == host_id, ) - results = self.conn.execute(relation_query).all() + results = self.db_execute(relation_query).all() # only add one if one doesn't already exist if not results: @@ -375,7 +354,7 @@ def add_loggedin_relation(self, user_id, host_id): # TODO: find a way to abstract this away to a single Upsert call q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id) - self.conn.execute(q, [relation]) # .scalar() + self.db_execute(q, [relation]) # .scalar() except Exception as e: nxc_logger.debug(f"Error inserting LoggedinRelation: {e}") @@ -385,7 +364,7 @@ def get_loggedin_relations(self, user_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.userid == user_id) if host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - return self.conn.execute(q).all() + return self.db_execute(q).all() def remove_loggedin_relations(self, user_id=None, host_id=None): q = delete(self.LoggedinRelationsTable) @@ -393,4 +372,4 @@ def remove_loggedin_relations(self, user_id=None, host_id=None): q = q.filter(self.LoggedinRelationsTable.c.userid == user_id) elif host_id: q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id) - self.conn.execute(q) + self.db_execute(q)