Skip to content

Commit

Permalink
Allowing to configure "logger" in tds::printReplies function; both fo…
Browse files Browse the repository at this point in the history
…r error and info. (#1795)

Change mssqlattack.py and mssqlshell.py to align with changes
  • Loading branch information
gabrielg5 authored Aug 27, 2024
1 parent 2509ca1 commit 0656b48
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 52 deletions.
58 changes: 19 additions & 39 deletions impacket/examples/mssqlshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@
import cmd
import sys

def handle_lastError(f):
def wrapper(*args):
try:
f(*args)
finally:
if(args[0].sql.lastError):
print(args[0].sql.lastError)
return wrapper

class SQLSHELL(cmd.Cmd):
def __init__(self, SQL, show_queries=False, tcpShell=None):
if tcpShell is not None:
Expand All @@ -49,6 +40,10 @@ def __init__(self, SQL, show_queries=False, tcpShell=None):
self.set_prompt()
self.intro = '[!] Press help for extra shell commands'

def print_replies(self):
# to condense all calls to sql.printReplies with right logger in this context
self.sql.printReplies(error_logger=print, info_logger=print)

def do_help(self, line):
print("""
lcd {path} - changes the current local directory to {path}
Expand Down Expand Up @@ -103,19 +98,16 @@ def execute_as(self, exec_as):
self.at.append((at, exec_as))
else:
self.sql_query(exec_as)
self.sql.printReplies()
self.print_replies()

@handle_lastError
def do_exec_as_login(self, s):
exec_as = "execute as login='%s';" % s
self.execute_as(exec_as)

@handle_lastError
def do_exec_as_user(self, s):
exec_as = "execute as user='%s';" % s
self.execute_as(exec_as)

@handle_lastError
def do_use_link(self, s):
if s == 'localhost':
self.at = []
Expand All @@ -124,7 +116,7 @@ def do_use_link(self, s):
else:
self.at.append((s, ''))
row = self.sql_query('select system_user as "username"')
self.sql.printReplies()
self.print_replies()
if len(row) < 1:
self.at = self.at[:-1]

Expand All @@ -139,26 +131,23 @@ def sql_query(self, query, show=True):
def do_shell(self, s):
os.system(s)

@handle_lastError
def do_xp_dirtree(self, s):
try:
self.sql_query("exec master.sys.xp_dirtree '%s',1,1" % s)
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def do_xp_cmdshell(self, s):
try:
self.sql_query("exec master..xp_cmdshell '%s'" % s)
self.sql.printReplies()
self.print_replies()
self.sql.colMeta[0]['TypeData'] = 80*2
self.sql.printRows()
except:
pass

@handle_lastError
def do_sp_start_job(self, s):
try:
self.sql_query("DECLARE @job NVARCHAR(100);"
Expand All @@ -169,7 +158,7 @@ def do_sp_start_job(self, s):
"@subsystem='CMDEXEC',@command='%s',@on_success_action=1;"
"EXEC msdb..sp_add_jobserver @job_name=@job;"
"EXEC msdb..sp_start_job @job_name=@job;" % s)
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass
Expand All @@ -180,60 +169,53 @@ def do_lcd(self, s):
else:
os.chdir(s)

@handle_lastError
def do_enable_xp_cmdshell(self, line):
try:
self.sql_query("exec master.dbo.sp_configure 'show advanced options',1;RECONFIGURE;"
"exec master.dbo.sp_configure 'xp_cmdshell', 1;RECONFIGURE;")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def do_disable_xp_cmdshell(self, line):
try:
self.sql_query("exec sp_configure 'xp_cmdshell', 0 ;RECONFIGURE;exec sp_configure "
"'show advanced options', 0 ;RECONFIGURE;")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def do_enum_links(self, line):
self.sql_query("EXEC sp_linkedservers")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
self.sql_query("EXEC sp_helplinkedsrvlogin")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()

@handle_lastError
def do_enum_users(self, line):
self.sql_query("EXEC sp_helpuser")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()

@handle_lastError
def do_enum_db(self, line):
try:
self.sql_query("select name, is_trustworthy_on from sys.databases")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def do_enum_owner(self, line):
try:
self.sql_query("SELECT name [Database], suser_sname(owner_sid) [Owner] FROM sys.databases")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def do_enum_impersonate(self, line):
old_db = self.sql.currentDB
try:
Expand All @@ -260,31 +242,29 @@ def do_enum_impersonate(self, line):
" ON pe.grantor_principal_id = pr2.principal_Id "
"WHERE pe.type = 'IM'")
result.extend(result_rows)
self.sql.printReplies()
self.print_replies()
self.sql.rows = result
self.sql.printRows()
except:
pass
finally:
self.sql_query("use " + old_db)

@handle_lastError
def do_enum_logins(self, line):
try:
self.sql_query("select r.name,r.type_desc,r.is_disabled, sl.sysadmin, sl.securityadmin, "
"sl.serveradmin, sl.setupadmin, sl.processadmin, sl.diskadmin, sl.dbcreator, "
"sl.bulkadmin from master.sys.server_principals r left join master.sys.syslogins sl "
"on sl.sid = r.sid where r.type in ('S','E','X','U','G')")
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass

@handle_lastError
def default(self, line):
try:
self.sql_query(line)
self.sql.printReplies()
self.print_replies()
self.sql.printRows()
except:
pass
Expand Down
10 changes: 3 additions & 7 deletions impacket/examples/ntlmrelayx/attacks/mssqlattack.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,9 @@ def run(self):
if self.config.queries is not None:
for query in self.config.queries:
LOG.info('Executing SQL: %s' % query)
try:
self.client.sql_query(query)
self.client.printReplies()
self.client.printRows()
finally:
if(self.client.lastError):
print(self.client.lastError)
self.client.sql_query(query)
self.client.printReplies()
self.client.printRows()
else:
LOG.error('No SQL queries specified for MSSQL relay!')

12 changes: 6 additions & 6 deletions impacket/tds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,18 +1015,18 @@ def printRows(self):
self.__rowsPrinter.logMessage(col['Format'] % row[col['Name']] + self.COL_SEPARATOR)
self.__rowsPrinter.logMessage('\n')

def printReplies(self):
def printReplies(self, error_logger=LOG.error, info_logger=LOG.info):
for keys in list(self.replies.keys()):
for i, key in enumerate(self.replies[keys]):
if key['TokenType'] == TDS_ERROR_TOKEN:
error = "ERROR(%s): Line %d: %s" % (key['ServerName'].decode('utf-16le'), key['LineNumber'], key['MsgText'].decode('utf-16le'))
self.lastError = SQLErrorException("ERROR: Line %d: %s" % (key['LineNumber'], key['MsgText'].decode('utf-16le')))
self.lastError = SQLErrorException("ERROR(%s): Line %d: %s" % (key['ServerName'].decode('utf-16le'), key['LineNumber'], key['MsgText'].decode('utf-16le')))
error_logger(self.lastError)

elif key['TokenType'] == TDS_INFO_TOKEN:
LOG.info("INFO(%s): Line %d: %s" % (key['ServerName'].decode('utf-16le'), key['LineNumber'], key['MsgText'].decode('utf-16le')))
info_logger("INFO(%s): Line %d: %s" % (key['ServerName'].decode('utf-16le'), key['LineNumber'], key['MsgText'].decode('utf-16le')))

elif key['TokenType'] == TDS_LOGINACK_TOKEN:
LOG.info("ACK: Result: %s - %s (%d%d %d%d) " % (key['Interface'], key['ProgName'].decode('utf-16le'), key['MajorVer'], key['MinorVer'], key['BuildNumHi'], key['BuildNumLow']))
info_logger("ACK: Result: %s - %s (%d%d %d%d) " % (key['Interface'], key['ProgName'].decode('utf-16le'), key['MajorVer'], key['MinorVer'], key['BuildNumHi'], key['BuildNumLow']))

elif key['TokenType'] == TDS_ENVCHANGE_TOKEN:
if key['Type'] in (TDS_ENVCHANGE_DATABASE, TDS_ENVCHANGE_LANGUAGE, TDS_ENVCHANGE_CHARSET, TDS_ENVCHANGE_PACKETSIZE):
Expand All @@ -1045,7 +1045,7 @@ def printReplies(self):
_type = 'PACKETSIZE'
else:
_type = "%d" % key['Type']
LOG.info("ENVCHANGE(%s): Old Value: %s, New Value: %s" % (_type,record['OldValue'].decode('utf-16le'), record['NewValue'].decode('utf-16le')))
info_logger("ENVCHANGE(%s): Old Value: %s, New Value: %s" % (_type,record['OldValue'].decode('utf-16le'), record['NewValue'].decode('utf-16le')))

def parseRow(self,token,tuplemode=False):
# TODO: This REALLY needs to be improved. Right now we don't support correctly all the data types
Expand Down

0 comments on commit 0656b48

Please sign in to comment.