Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring to fix InterfaceError of DB #400

Merged
merged 6 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions nxc/database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import sys
import configparser
import shutil
from sqlalchemy import create_engine
from sqlite3 import connect
import sys
from os import mkdir
from os.path import exists
from os.path import join as path_join
from pathlib import Path
from sqlite3 import connect
from threading import Lock

from sqlalchemy import create_engine, MetaData
from sqlalchemy.exc import IllegalStateChangeError
from sqlalchemy.orm import sessionmaker, scoped_session

from nxc.loaders.protocolloader import ProtocolLoader
from nxc.logger import nxc_logger
from nxc.paths import WORKSPACE_DIR


Expand Down Expand Up @@ -103,3 +109,39 @@ def initialize_db():

# Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol)
init_protocol_dbs("default")


class BaseDB:
def __init__(self, db_engine):
self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)

session = scoped_session(session_factory)
self.sess = session()
self.lock = Lock()

def reflect_tables(self):
raise NotImplementedError("Reflect tables not implemented")

def shutdown_db(self):
try:
self.sess.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.db_execute(table.delete())

def db_execute(self, *args):
self.lock.acquire()
res = self.sess.execute(*args)
self.lock.release()
return res
77 changes: 29 additions & 48 deletions nxc/protocols/ftp/database.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
from pathlib import Path
import sys

from sqlalchemy import Table, select, delete, func
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy import MetaData, Table, select, delete, func
from sqlalchemy.exc import (
IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)

from nxc.database import BaseDB
from nxc.logger import nxc_logger
import sys


class database:
class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
self.LoggedinRelationsTable = None

self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()

session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
Session = scoped_session(session_factory)
self.sess = Session()
super().__init__(db_engine)

@staticmethod
def db_schema(db_conn):
Expand Down Expand Up @@ -80,26 +72,13 @@ def reflect_tables(self):
)
sys.exit()

def shutdown_db(self):
try:
self.sess.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.sess.execute(table.delete())

def add_host(self, host, port, banner):
"""Check if this host is already in the DB, if not add it"""
hosts = []
updated_ids = []

q = select(self.HostsTable).filter(self.HostsTable.c.host == host)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()

# create new host
if not results:
Expand Down Expand Up @@ -133,7 +112,7 @@ def add_host(self, host, port, banner):
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)

self.sess.execute(q, hosts) # .scalar()
self.db_execute(q, hosts) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
Expand All @@ -143,8 +122,9 @@ def add_credential(self, username, password):
"""Check if this credential has already been added to the database, if not add it in."""
credentials = []

q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.password) == func.lower(password))
results = self.sess.execute(q).all()
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username),
func.lower(self.CredentialsTable.c.password) == func.lower(password))
results = self.db_execute(q).all()

# add new credential
if not results:
Expand All @@ -170,10 +150,11 @@ def add_credential(self, username, password):
# TODO: find a way to abstract this away to a single Upsert call
q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id)
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users)
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key,
set_=update_columns_users)
nxc_logger.debug(f"Adding credentials: {credentials}")

self.sess.execute(q_users, credentials) # .scalar()
self.db_execute(q_users, credentials) # .scalar()

# hacky way to get cred_id since we can't use returning() yet
if len(credentials) == 1:
Expand All @@ -187,23 +168,23 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
del_hosts.append(q)
self.sess.execute(q)
self.db_execute(q)

def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
q = select(self.CredentialsTable).filter(
self.CredentialsTable.c.id == credential_id,
self.CredentialsTable.c.password is not None,
)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_credential(self, username, password):
q = select(self.CredentialsTable).filter(
self.CredentialsTable.c.username == username,
self.CredentialsTable.c.password == password,
)
results = self.sess.execute(q).first()
results = self.db_execute(q).first()
if results is not None:
return results.id

Expand All @@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None):
else:
q = select(self.CredentialsTable)

return self.sess.execute(q).all()
return self.db_execute(q).all()

def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_hosts(self, filter_term=None):
Expand All @@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
results = self.sess.execute(q).first()
results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by host
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.host.like(like_term))
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
nxc_logger.debug(f"FTP get_hosts() - results: {results}")
return results

def is_user_valid(self, cred_id):
"""Check if this User ID is valid."""
q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_user(self, username):
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username))
return self.sess.execute(q).all()
return self.db_execute(q).all()

def get_users(self, filter_term=None):
q = select(self.CredentialsTable)
Expand All @@ -265,14 +246,14 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term))
return self.sess.execute(q).all()
return self.db_execute(q).all()

def add_loggedin_relation(self, cred_id, host_id):
relation_query = select(self.LoggedinRelationsTable).filter(
self.LoggedinRelationsTable.c.credid == cred_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
results = self.sess.execute(relation_query).all()
results = self.db_execute(relation_query).all()

# only add one if one doesn't already exist
if not results:
Expand All @@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)

self.sess.execute(q, [relation]) # .scalar()
self.db_execute(q, [relation]) # .scalar()
inserted_id_results = self.get_loggedin_relations(cred_id, host_id)
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
return inserted_id_results[0].id
Expand All @@ -295,15 +276,15 @@ def get_loggedin_relations(self, cred_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
if host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
return self.sess.execute(q).all()
return self.db_execute(q).all()

def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
if cred_id:
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
self.sess.execute(q)
self.db_execute(q)

def add_directory_listing(self, lir_id, data):
pass
Expand Down
37 changes: 7 additions & 30 deletions nxc/protocols/ldap/database.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
from pathlib import Path
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy import MetaData, Table
import sys

from sqlalchemy import Table
from sqlalchemy.exc import (
IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
from nxc.logger import nxc_logger
import sys

from nxc.database import BaseDB


class database:
class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None

self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)

Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
self.conn = Session()
super().__init__(db_engine)

@staticmethod
def db_schema(db_conn):
Expand Down Expand Up @@ -59,16 +49,3 @@ def reflect_tables(self):
[-] Then remove the nxc {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()

def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.conn.execute(table.delete())
Loading