Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NLA redirection #308

Merged
merged 5 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions pyrdp/mitm/RDPMITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ class RDPMITM:
Main MITM class. The job of this class is to orchestrate the components for all the protocols.
"""

def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, config: MITMConfig, state: RDPMITMState=None, recorder: Recorder=None):
def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, config: MITMConfig, state: RDPMITMState = None, recorder: Recorder = None):
"""
:param log: base logger to use for the connection
:param mainLogger: base logger to use for the connection
:param config: the MITM configuration
"""

Expand Down Expand Up @@ -94,11 +94,13 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf
self.channelMITMs = {}
"""MITM components for virtual channels"""

serverConnector = self.connectToServer()
self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, serverConnector, self.statCounter)
self.onTlsReady = None
"""Callback for when TLS is done"""

self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, self.statCounter)
"""TCP MITM component"""

self.x224 = X224MITM(self.client.x224, self.server.x224, self.getLog("x224"), self.state, serverConnector, self.startTLS)
self.x224 = X224MITM(self.client.x224, self.server.x224, self.getLog("x224"), self.state, self.connectToServer, self.disconnectFromServer, self.startTLS)
"""X224 MITM component"""

self.mcs = MCSMITM(self.client.mcs, self.server.mcs, self.state, self.recorder, self.buildChannel, self.getLog("mcs"), self.statCounter)
Expand Down Expand Up @@ -140,15 +142,17 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf

if config.recordReplays:
date = datetime.datetime.now()
replayFileName = "rdp_replay_{}_{}_{}.pyrdp"\
.format(date.strftime('%Y%m%d_%H-%M-%S'),
date.microsecond // 1000,
self.state.sessionID)
replayFileName = f"rdp_replay_{date.strftime('%Y%m%d_%H-%M-%S')}_{date.microsecond // 1000}_{self.state.sessionID}.pyrdp"
self.recorder.setRecordFilename(replayFileName)
self.recorder.addTransport(FileLayer(self.config.replayDir / replayFileName))

if config.enableCrawler:
self.crawler: FileCrawlerMITM = FileCrawlerMITM(self.getClientLog(MCSChannelName.DEVICE_REDIRECTION).createChild("crawler"), crawlerLogger, self.config, self.state)
self.crawler: FileCrawlerMITM = FileCrawlerMITM(
self.getClientLog(MCSChannelName.DEVICE_REDIRECTION).createChild("crawler"),
crawlerLogger,
self.config,
self.state
)

def getProtocol(self) -> Protocol:
"""
Expand Down Expand Up @@ -177,6 +181,10 @@ def getServerLog(self, name: str) -> SessionLogger:
"""
return self.serverLog.createChild(name)

def disconnectFromServer(self):
self.server.replaceTCP()
self.tcp.setServer(self.server.tcp)

async def connectToServer(self):
"""
Coroutine that connects to the target RDP server and the attacker.
Expand All @@ -185,20 +193,20 @@ async def connectToServer(self):

serverFactory = AwaitableClientFactory(self.server.tcp)
if self.config.transparent:
src = self.client.tcp.transport.client
if self.config.targetHost:
# Fully Transparent (with a specific poisoned target.)
connectTransparent(self.config.targetHost, self.config.targetPort, serverFactory, bindAddress=(src[0], 0))
if self.state.effectiveTargetHost:
# Fully Transparent (with a specific poisoned target, or when using redirection)
src = self.client.tcp.transport.client
connectTransparent(self.state.effectiveTargetHost, self.state.effectiveTargetPort, serverFactory, bindAddress=(src[0], 0))
else:
# Half Transparent (for client-side only)
dst = self.client.tcp.transport.getHost().host
reactor.connectTCP(dst, self.config.targetPort, serverFactory)
reactor.connectTCP(dst, self.state.effectiveTargetPort, serverFactory)
else:
reactor.connectTCP(self.config.targetHost, self.config.targetPort, serverFactory)
reactor.connectTCP(self.state.effectiveTargetHost, self.state.effectiveTargetPort, serverFactory)

await serverFactory.connected.wait()

if self.config.attackerHost is not None and self.config.attackerPort is not None:
if self.config.attackerHost is not None and self.config.attackerPort is not None and not self.player.tcp.connectedEvent.is_set():
attackerFactory = AwaitableClientFactory(self.player.tcp)
reactor.connectTCP(self.config.attackerHost, self.config.attackerPort, attackerFactory)

Expand Down
37 changes: 20 additions & 17 deletions pyrdp/mitm/TCPMITM.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#
# This file is part of the PyRDP project.
# Copyright (C) 2019 GoSecure Inc.
# Copyright (C) 2019-2021 GoSecure Inc.
# Licensed under the GPLv3 or later.
#
import time
from logging import LoggerAdapter
from typing import Coroutine

from pyrdp.layer import TwistedTCPLayer
from pyrdp.logging.StatCounter import StatCounter
Expand All @@ -20,46 +18,52 @@ class TCPMITM:
"""

def __init__(self, client: TwistedTCPLayer, server: TwistedTCPLayer, attacker: TwistedTCPLayer, log: LoggerAdapter,
state: RDPMITMState, recorder: Recorder, serverConnector: Coroutine, statCounter: StatCounter):
state: RDPMITMState, recorder: Recorder, statCounter: StatCounter):
"""
:param client: TCP layer for the client side
:param server: TCP layer for the server side
:param attacker: TCP layer for the attacker side
:param log: logger for this component
:param recorder: recorder for this connection
:param serverConnector: coroutine that connects to the server side, closed when the client disconnects
"""

self.statCounter = statCounter
# To keep track of useful statistics for the connection.

self.client = client
self.server = server
self.server = None
self.attacker = attacker
self.log = log
self.state = state
self.recorder = recorder
self.serverConnector = serverConnector

# Allows a lower layer to raise error tagged with the correct sessionID
self.client.log = log
self.server.log = log

self.clientObserver = self.client.createObserver(
onConnection = self.onClientConnection,
onDisconnection = self.onClientDisconnection,
)

self.serverObserver = self.server.createObserver(
onConnection = self.onServerConnection,
onDisconnection = self.onServerDisconnection,
)

self.attacker.createObserver(
onConnection = self.onAttackerConnection,
onDisconnection = self.onAttackerDisconnection,
)

self.serverObserver = None
self.setServer(server)

def setServer(self, server: TwistedTCPLayer):
if self.server is not None:
self.server.removeObserver(self.serverObserver)
self.server.disconnect(True)

self.server = server
self.server.log = self.log
self.serverObserver = self.server.createObserver(
onConnection=self.onServerConnection,
onDisconnection=self.onServerDisconnection,
)

def detach(self):
"""
Remove the observers from the layers.
Expand All @@ -75,7 +79,6 @@ def onClientConnection(self):

# Statistics
self.statCounter.start()
self.connectionTime = time.time()

ip = self.client.transport.client[0]
self.log.info("New client connected from %(clientIp)s", {"clientIp": ip})
Expand All @@ -94,7 +97,7 @@ def onClientDisconnection(self, reason):
self.recorder.recordFilename})
else:
self.statCounter.logReport(self.log)
self.serverConnector.close()

self.server.disconnect(True)

# For the attacker, we want to make sure we don't abort the connection to make sure that the close event is sent
Expand Down Expand Up @@ -137,4 +140,4 @@ def onAttackerDisconnection(self, reason):

def recordConnectionClose(self):
pdu = PlayerConnectionClosePDU(self.recorder.getCurrentTimeStamp())
self.recorder.record(pdu, pdu.header)
self.recorder.record(pdu, pdu.header)
62 changes: 41 additions & 21 deletions pyrdp/mitm/X224MITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# Licensed under the GPLv3 or later.
#

import typing
from typing import Callable, Coroutine, Optional
from logging import LoggerAdapter

from pyrdp.core import defer
from pyrdp.enum import NegotiationFailureCode, NegotiationProtocols, NegotiationType, NegotiationRequestFlags
from pyrdp.enum import NegotiationFailureCode, NegotiationType, NegotiationRequestFlags
from pyrdp.layer import X224Layer
from pyrdp.mitm.state import RDPMITMState
from pyrdp.parser import NegotiationRequestParser, NegotiationResponseParser
Expand All @@ -17,14 +17,17 @@


class X224MITM:
def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, state: RDPMITMState, connector: typing.Coroutine, startTLSCallback: typing.Callable[[typing.Callable[[], None]], None]):
def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, state: RDPMITMState,
connector: Callable[[], Coroutine], disconnector: Callable[[], None],
startTLSCallback: Callable[[Callable[[], None]], None]):
"""

:param client: X224 layer for the client side
:param server: X224 layer for the server side
:param log: logger for this component
:param state: state of the MITM
:param connector: coroutine that connects to the server, awaited when a connection request is received
:param connector: function that connects to the server, called when a connection request is received
:param disconnector: function that disconnects from the server, called when using a redirection host and NLA is enforced
:param startTLSCallback: callback that should execute a startTLS on the client and server sides
"""

Expand All @@ -34,8 +37,10 @@ def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, sta
self.log = log
self.state = state
self.connector = connector
self.disconnector = disconnector
self.startTLSCallback = startTLSCallback
self.originalRequest: typing.Optional[NegotiationRequestPDU] = None
self.originalConnectionRequest: Optional[X224ConnectionRequestPDU] = None
self.originalNegotiationRequest: Optional[NegotiationRequestPDU] = None

self.client.createObserver(
onConnectionRequest = self.onConnectionRequest,
Expand All @@ -56,30 +61,31 @@ def onConnectionRequest(self, pdu: X224ConnectionRequestPDU):
"""

parser = NegotiationRequestParser()
self.originalRequest = parser.parse(pdu.payload)
self.state.requestedProtocols = self.originalRequest.requestedProtocols
self.originalConnectionRequest = pdu
self.originalNegotiationRequest = parser.parse(pdu.payload)
self.state.requestedProtocols = self.originalNegotiationRequest.requestedProtocols

if self.originalRequest.flags is not None and self.originalRequest.flags & NegotiationRequestFlags.RESTRICTED_ADMIN_MODE_REQUIRED:
if self.originalNegotiationRequest.flags is not None and self.originalNegotiationRequest.flags & NegotiationRequestFlags.RESTRICTED_ADMIN_MODE_REQUIRED:
self.log.warning("Client has enabled Restricted Admin Mode, which forces Network-Level Authentication (NLA)."
" Connection will fail.", {"restrictedAdminActivated": True})

if self.originalRequest.cookie:
self.log.info("%(cookie)s", {"cookie": self.originalRequest.cookie.decode()})
if self.originalNegotiationRequest.cookie:
self.log.info("%(cookie)s", {"cookie": self.originalNegotiationRequest.cookie.decode()})
else:
self.log.info("No cookie for this connection")

chosenProtocols = self.originalRequest.requestedProtocols
chosenProtocols = self.originalNegotiationRequest.requestedProtocols

if chosenProtocols is not None:
# Tell the server we only support the allowed authentication methods.
chosenProtocols &= self.state.config.authMethods

modifiedRequest = NegotiationRequestPDU(
self.originalRequest.cookie,
self.originalRequest.flags,
self.originalNegotiationRequest.cookie,
self.originalNegotiationRequest.flags,
chosenProtocols,
self.originalRequest.correlationFlags,
self.originalRequest.correlationID,
self.originalNegotiationRequest.correlationFlags,
self.originalNegotiationRequest.correlationID,
)

payload = parser.write(modifiedRequest)
Expand All @@ -90,13 +96,13 @@ async def connectToServer(self, payload: bytes):
Awaits the coroutine that connects to the server.
:param payload: the connection request payload
"""
await self.connector
await self.connector()
self.server.sendConnectionRequest(payload = payload)

def onConnectionConfirm(self, pdu: X224ConnectionConfirmPDU):
"""
Execute a startTLS if the SSL protocol was selected.
:param _: the connection confirm PDU
:param pdu: the connection confirm PDU
"""

# FIXME: In case the server picks anything other than what we support, PyRDP is
Expand All @@ -108,14 +114,28 @@ def onConnectionConfirm(self, pdu: X224ConnectionConfirmPDU):
parser = NegotiationResponseParser()
response = parser.parse(pdu.payload)
if isinstance(response, NegotiationFailurePDU):
self.log.info("The server failed the negotiation. Error: %(error)s", {"error": NegotiationFailureCode.getMessage(response.failureCode)})
payload = pdu.payload
if response.failureCode == NegotiationFailureCode.HYBRID_REQUIRED_BY_SERVER and self.state.canRedirect():
self.log.info("The server forces the use of NLA. Using redirection host: %(redirectionHost)s:%(redirectionPort)d", {
"redirectionHost": self.state.config.redirectionHost,
"redirectionPort": self.state.config.redirectionPort
})

# Disconnect from current server
self.disconnector()

# Use redirection host and replay sequence starting from the connection request
self.state.useRedirectionHost()
self.onConnectionRequest(self.originalConnectionRequest)
return
else:
self.log.info("The server failed the negotiation. Error: %(error)s", {"error": NegotiationFailureCode.getMessage(response.failureCode)})
payload = pdu.payload
else:
payload = parser.write(NegotiationResponsePDU(NegotiationType.TYPE_RDP_NEG_RSP, 0x00, response.selectedProtocols))

# FIXME: This should be done based on what authentication method the server selected, not on what
# the client supports.
if self.originalRequest.tlsSupported:
if self.originalNegotiationRequest.tlsSupported:
# If a TLS tunnel is requested, then we establish the server-side tunnel before
# replying to the client, so that we can clone the certificate if needed.
self.startTLSCallback(lambda: self.client.sendConnectionConfirm(payload, source=0x1234))
Expand All @@ -133,4 +153,4 @@ def onClientError(self, pdu: X224ErrorPDU):
self.server.sendPDU(pdu)

def onServerError(self, pdu: X224ErrorPDU):
self.client.sendPDU(pdu)
self.client.sendPDU(pdu)
8 changes: 8 additions & 0 deletions pyrdp/mitm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def buildArgParser():
"--transparent", help="Spoof source IP for connections to the server (See README)", action="store_true")
parser.add_argument("--no-gdi", help="Disable accelerated graphics pipeline (MS-RDPEGDI) extension",
action="store_true")
parser.add_argument("--nla-redirection-host", help="Redirection target ip if NLA is enforced", default=None)
parser.add_argument("--nla-redirection-port", help="Redirection target port if NLA is enforced", type=int, default=None)

return parser

Expand Down Expand Up @@ -163,6 +165,10 @@ def configure(cmdline=None) -> MITMConfig:
sys.stderr.write('error: A relay target is required unless running in transparent proxy mode.\n')
sys.exit(1)

if (args.nla_redirection_host is None) != (args.nla_redirection_port is None):
sys.stderr.write('Error: please provide both --nla-redirection-host and --nla-redirection-port')
sys.exit(1)

if args.target:
targetHost, targetPort = parseTarget(args.target)
else:
Expand Down Expand Up @@ -191,6 +197,8 @@ def configure(cmdline=None) -> MITMConfig:
config.extractFiles = not args.no_files
config.disableActiveClipboardStealing = args.disable_active_clipboard
config.useGdi = not args.no_gdi
config.redirectionHost = args.nla_redirection_host
config.redirectionPort = args.nla_redirection_port

payload = None
powershell = None
Expand Down
6 changes: 6 additions & 0 deletions pyrdp/mitm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def __init__(self):
self.authMethods: NegotiationProtocols = NegotiationProtocols.SSL
"""Specifies the list of authentication protocols that PyRDP accepts."""

self.redirectionHost = None
"""Host to redirect the connection to if NLA is enforced"""

self.redirectionPort = None
"""Port of the redirection host"""

@property
def replayDir(self) -> Path:
"""
Expand Down
7 changes: 6 additions & 1 deletion pyrdp/mitm/layerset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# This file is part of the PyRDP project.
# Copyright (C) 2019 GoSecure Inc.
# Copyright (C) 2019-2021 GoSecure Inc.
# Licensed under the GPLv3 or later.
#

Expand All @@ -27,3 +27,8 @@ def __init__(self):
self.tcp.setNext(self.segmentation)
self.segmentation.attachLayer(SegmentationPDUType.TPKT, self.tpkt)
LayerChainItem.chain(self.tpkt, self.x224, self.mcs)

def replaceTCP(self):
self.tcp = TwistedTCPLayer()
self.tcp.setNext(self.segmentation)
self.segmentation.attachLayer(SegmentationPDUType.TPKT, self.tpkt)
Loading