diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 98351c76c..a31d6de07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,52 +35,56 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v1 - with: - python-version: '3.7' # Version range or exact version of a Python version to use, using semvers version range syntax. - architecture: 'x64' - - - name: Python version - run: python --version - - name: Pip version - run: pip --version - - - name: Install setuptools - run: sudo apt install python3-setuptools - - name: Install PyRDP dependencies - run: sudo apt install libdbus-1-dev libdbus-glib-1-dev libgl1-mesa-glx git python3-dev - - name: Install wheel - working-directory: . - run: pip install wheel - - name: Install PyRDP - working-directory: . - run: pip install -U -e .[full] - - - name: Install ci dependencies - run: pip install -r requirements-ci.txt - - - name: Extract test files - uses: DuckSoft/extract-7z-action@v1.0 - with: - pathSource: test/files/test_files.zip - pathTarget: test/files - - - name: Integration Test with a prerecorded PCAP. - working-directory: ./ - run: coverage run test/test_prerecorded.py - - - name: pyrdp-mitm.py initialization integration test - working-directory: ./ - run: coverage run --append test/test_mitm_initialization.py dummy_value - - - name: pyrdp-player.py read a replay in headless mode test - working-directory: ./ - run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay - - - name: Coverage - working-directory: ./ - run: coverage report --fail-under=40 + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: '3.7' # Version range or exact version of a Python version to use, using semvers version range syntax. + architecture: 'x64' + + - name: Python version + run: python --version + - name: Pip version + run: pip --version + + - name: Install setuptools + run: sudo apt install python3-setuptools + - name: Install PyRDP dependencies + run: sudo apt install libdbus-1-dev libdbus-glib-1-dev libgl1-mesa-glx git python3-dev + - name: Install wheel + working-directory: . + run: pip install wheel + - name: Install PyRDP + working-directory: . + run: pip install -U -e .[full] + + - name: Install ci dependencies + run: pip install -r requirements-ci.txt + + - name: Extract test files + uses: DuckSoft/extract-7z-action@v1.0 + with: + pathSource: test/files/test_files.zip + pathTarget: test/files + + - name: Integration Test with a prerecorded PCAP. + working-directory: ./ + run: coverage run test/test_prerecorded.py + + - name: pyrdp-mitm.py initialization integration test + working-directory: ./ + run: coverage run --append test/test_mitm_initialization.py dummy_value + + - name: pyrdp-player.py read a replay in headless mode test + working-directory: ./ + run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay + + - name: Run unit tests + working-directory: ./ + run: coverage run --append -m unittest discover -v + + - name: Coverage report + working-directory: ./ + run: coverage report --fail-under=40 @@ -104,6 +108,9 @@ jobs: - name: Install PyRDP working-directory: . run: pip install -U -e .[full] + - name: Install coverage + working-directory: . + run: pip install coverage - name: Extract test files uses: DuckSoft/extract-7z-action@v1.0 @@ -113,12 +120,20 @@ jobs: - name: Integration Test with a prerecorded PCAP. working-directory: ./ - run: python test/test_prerecorded.py + run: coverage run test/test_prerecorded.py - name: pyrdp-mitm.py initialization test working-directory: ./ - run: python test/test_mitm_initialization.py dummy_value + run: coverage run --append test/test_mitm_initialization.py dummy_value - name: pyrdp-player.py read a replay in headless mode test working-directory: ./ - run: python bin/pyrdp-player.py --headless test/files/test_session.replay + run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay + + - name: Run unit tests + working-directory: ./ + run: coverage run --append -m unittest discover -v + + - name: Coverage report + working-directory: ./ + run: coverage report --fail-under=40 diff --git a/.gitignore b/.gitignore index 8a4fab6f2..622b86f17 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ mitm.json # twisted /twisted/plugins/dropin.cache + +# code coverage +htmlcov/ diff --git a/bin/pyrdp-convert.py b/bin/pyrdp-convert.py index 349e95c25..701aae66e 100755 --- a/bin/pyrdp-convert.py +++ b/bin/pyrdp-convert.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): # We'll set up the recorder ourselves config.recordReplays = False - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) sink, outfile = getSink(format, output_path) transport = ConversionLayer(sink) if sink else FileLayer(outfile) diff --git a/pyrdp/mitm/ClipboardMITM.py b/pyrdp/mitm/ClipboardMITM.py index 8b6d8779a..29d1931fd 100644 --- a/pyrdp/mitm/ClipboardMITM.py +++ b/pyrdp/mitm/ClipboardMITM.py @@ -14,6 +14,7 @@ from pyrdp.enum import ClipboardFormatNumber, ClipboardMessageFlags, ClipboardMessageType, PlayerPDUType, FileContentsFlags from pyrdp.layer import ClipboardLayer from pyrdp.logging.StatCounter import StatCounter, STAT +from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import ClipboardPDU, FormatDataRequestPDU, FormatDataResponsePDU, FileContentsRequestPDU, FileContentsResponsePDU from pyrdp.parser.rdp.virtual_channel.clipboard import FileDescriptor from pyrdp.recording import Recorder @@ -32,7 +33,7 @@ class PassiveClipboardStealer: """ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, - statCounter: StatCounter): + statCounter: StatCounter, state: RDPMITMState): """ :param client: clipboard layer for the client side :param server: clipboard layer for the server side @@ -44,13 +45,14 @@ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: Clipboard self.server = server self.config = config self.log = log + self.state = state self.recorder = recorder self.forwardNextDataResponse = True self.files = [] self.transfers = {} self.timeouts = {} # Track active timeout monitoring tasks. - self.fileDir = f"{self.config.fileDir}/{self.log.sessionID}" + self.fileDir = f"{self.config.fileDir}/{self.state.sessionID}" self.client.createObserver( onPDUReceived = self.onClientPDUReceived, @@ -206,8 +208,8 @@ class ActiveClipboardStealer(PassiveClipboardStealer): """ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, - statCounter: StatCounter): - super().__init__(config, client, server, log, recorder, statCounter) + statCounter: StatCounter, state: RDPMITMState): + super().__init__(config, client, server, log, recorder, statCounter, state) def handlePDU(self, pdu: ClipboardPDU, destination: ClipboardLayer): """ diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 0c6ee804c..881c4fbf7 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -4,20 +4,17 @@ # Licensed under the GPLv3 or later. # -import hashlib -import json from logging import LoggerAdapter -from pathlib import Path from typing import Dict, Optional, Union -from pyrdp.core import FileProxy, ObservedBy, Observer, Subject -from pyrdp.enum import CreateOption, DeviceRedirectionPacketID, DeviceType, DirectoryAccessMask, FileAccessMask, FileAttributes, \ +from pyrdp.core import ObservedBy, Observer, Subject +from pyrdp.enum import CreateOption, DeviceRedirectionPacketID, DeviceType, DirectoryAccessMask, FileAccessMask, \ + FileAttributes, \ FileCreateDisposition, FileCreateOptions, FileShareAccess, FileSystemInformationClass, IOOperationSeverity, \ MajorFunction, MinorFunction from pyrdp.layer import DeviceRedirectionLayer from pyrdp.logging.StatCounter import StatCounter, STAT -from pyrdp.mitm.config import MITMConfig -from pyrdp.mitm.FileMapping import FileMapping, FileMappingDecoder, FileMappingEncoder +from pyrdp.mitm.FileMapping import FileMapping from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import DeviceAnnounce, DeviceCloseRequestPDU, DeviceCloseResponsePDU, DeviceCreateRequestPDU, \ DeviceCreateResponsePDU, DeviceDirectoryControlResponsePDU, DeviceIORequestPDU, DeviceIOResponsePDU, \ @@ -60,7 +57,8 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye :param client: device redirection layer for the client side :param server: device redirection layer for the server side :param log: logger for this component - :param config: MITM configuration + :param statCounter: stat counter object + :param state: shared RDP MITM state """ super().__init__() @@ -69,10 +67,8 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye self.state = state self.log = log self.statCounter = statCounter - self.openedFiles: Dict[int, FileProxy] = {} - self.openedMappings: Dict[int, FileMapping] = {} - self.fileMap: Dict[str, FileMapping] = {} - self.fileMapPath = self.config.outDir / "mapping.json" + self.mappings: Dict[(int, int), FileMapping] = {} + self.filesystemRoot = self.config.filesystemDir / self.state.sessionID self.currentIORequests: Dict[(int, int), DeviceIORequestPDU] = {} self.forgedRequests: Dict[(int, int), DeviceRedirectionMITM.ForgedRequest] = {} @@ -94,28 +90,18 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye onPDUReceived=self.onServerPDUReceived, ) - try: - with open(self.fileMapPath, "r") as f: - self.fileMap: Dict[str, FileMapping] = json.loads(f.read(), cls=FileMappingDecoder) - except IOError: - self.log.warning("Could not read the RDPDR file mapping at %(path)s. The file may not exist or it may have incorrect permissions. A new mapping will be created.", { - "path": str(self.fileMapPath), - }) - except json.JSONDecodeError: - self.log.error("Failed to decode file mapping, overwriting previous file") + def deviceRoot(self, deviceID: int): + return self.filesystemRoot / f"device{deviceID}" + + def createDeviceRoot(self, deviceID: int): + path = self.deviceRoot(deviceID) + path.mkdir(parents=True, exist_ok=True) + return path @property def config(self): return self.state.config - def saveMapping(self): - """ - Save the file mapping to a file in JSON format. - """ - - with open(self.fileMapPath, "w") as f: - f.write(json.dumps(self.fileMap, cls=FileMappingEncoder, indent=4, sort_keys=True)) - def onClientPDUReceived(self, pdu: DeviceRedirectionPDU): self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_CLIENT) self.handlePDU(pdu, self.server) @@ -135,7 +121,7 @@ def handlePDU(self, pdu: DeviceRedirectionPDU, destination: DeviceRedirectionLay if isinstance(pdu, DeviceIORequestPDU) and destination is self.client: self.handleIORequest(pdu) elif isinstance(pdu, DeviceIOResponsePDU) and destination is self.server: - dropPDU = pdu.completionID in self.forgedRequests + dropPDU = (pdu.deviceID, pdu.completionID) in self.forgedRequests self.handleIOResponse(pdu) elif isinstance(pdu, DeviceListAnnounceRequest): @@ -198,6 +184,7 @@ def handleDeviceListAnnounceRequest(self, pdu: DeviceListAnnounceRequest): "deviceName": device.preferredDOSName }) + self.createDeviceRoot(device.deviceID) self.observer.onDeviceAnnounce(device) def handleCreateResponse(self, request: DeviceCreateRequestPDU, response: DeviceCreateResponsePDU): @@ -210,23 +197,12 @@ def handleCreateResponse(self, request: DeviceCreateRequestPDU, response: Device """ isFileRead = request.desiredAccess & (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA) != 0 - isNotDirectory = request.createOptions & CreateOption.FILE_NON_DIRECTORY_FILE != 0 - - if isFileRead and isNotDirectory: - remotePath = Path(request.path) - mapping = FileMapping.generate(remotePath, self.config.fileDir) - proxy = FileProxy(mapping.localPath, "wb") - - key = (response.deviceID, response.completionID, response.fileID) - self.openedFiles[key] = proxy - self.openedMappings[key] = mapping - - proxy.createObserver( - onFileCreated = lambda _: self.log.info("Saving file '%(remotePath)s' to '%(localPath)s'", { - "localPath": mapping.localPath, "remotePath": mapping.remotePath - }), - onFileClosed = lambda _: self.log.debug("Closing file %(path)s", {"path": mapping.localPath}) - ) + isDirectory = request.createOptions & CreateOption.FILE_NON_DIRECTORY_FILE == 0 + + if isFileRead and not isDirectory: + mapping = FileMapping.generate(request.path, self.config.fileDir, self.deviceRoot(response.deviceID), self.log) + key = (response.deviceID, response.fileID) + self.mappings[key] = mapping def handleReadResponse(self, request: DeviceReadRequestPDU, response: DeviceReadResponsePDU): """ @@ -234,66 +210,27 @@ def handleReadResponse(self, request: DeviceReadRequestPDU, response: DeviceRead :param request: the device read request :param response: the device IO response to the request """ - key = (response.deviceID, response.completionID, request.fileID) - - if key in self.openedFiles: - file = self.openedFiles[key] - file.seek(request.offset) - file.write(response.payload) + key = (response.deviceID, request.fileID) - # Save the mapping permanently - mapping = self.openedMappings[key] - fileName = mapping.localPath.name - - if fileName not in self.fileMap: - self.fileMap[fileName] = mapping - self.saveMapping() + if key in self.mappings: + mapping = self.mappings[key] + mapping.seek(request.offset) + mapping.write(response.payload) def handleCloseResponse(self, request: DeviceCloseRequestPDU, response: DeviceCloseResponsePDU): """ Close the file if it was open. Compute the hash of the file, then delete it if we already have a file with the same hash. :param request: the device close request - :param _: the device IO response to the request + :param response: the device IO response to the request """ self.statCounter.increment(STAT.DEVICE_REDIRECTION_FILE_CLOSE) - key = (response.deviceID, response.completionID, request.fileID) + key = (response.deviceID, request.fileID) - if key in self.openedFiles: - file = self.openedFiles.pop(key) - file.close() - - if file.file is None: - return - - currentMapping = self.openedMappings.pop(key) - - # Compute the hash for the final file - with open(currentMapping.localPath, "rb") as f: - sha1 = hashlib.sha1() - - while True: - buffer = f.read(65536) - - if len(buffer) == 0: - break - - sha1.update(buffer) - - currentMapping.hash = sha1.hexdigest() - - # Check if a file with the same hash exists. If so, keep that one and remove the current file. - for localPath, mapping in self.fileMap.items(): - if mapping is currentMapping: - continue - - if mapping.hash == currentMapping.hash: - currentMapping.localPath.unlink() - self.fileMap.pop(currentMapping.localPath.name) - break - - self.saveMapping() + if key in self.mappings: + mapping = self.mappings.pop(key) + mapping.finalize() def handleClientLogin(self): """ @@ -301,7 +238,9 @@ def handleClientLogin(self): """ if self.state.credentialsCandidate or self.state.inputBuffer: - self.log.info("Credentials candidate from heuristic: %(credentials_candidate)s", {"credentials_candidate" : (self.state.credentialsCandidate or self.state.inputBuffer) }) + self.log.info("Credentials candidate from heuristic: %(credentials_candidate)s", { + "credentials_candidate" : (self.state.credentialsCandidate or self.state.inputBuffer) + }) # Deactivate the logger for this client self.state.loggedIn = True @@ -318,7 +257,7 @@ def findNextRequestID(self) -> int: """ completionID = DeviceRedirectionMITM.FORGED_COMPLETION_ID - while completionID in self.forgedRequests: + while completionID in [key[1] for key in self.forgedRequests]: completionID += 1 return completionID @@ -333,7 +272,7 @@ def sendForgedFileRead(self, deviceID: int, path: str) -> int: if not self.config.extractFiles: self.log.info('Ignored attempt to forge file reads because file extraction is disabled.') - return + return 0 self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_FILE_READ) @@ -357,7 +296,7 @@ def sendForgedDirectoryListing(self, deviceID: int, path: str) -> int: if not self.config.extractFiles: self.log.info('Ignored attempt to forge directory listing because file extraction is disabled.') - return + return 0 self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING) @@ -516,7 +455,7 @@ def send(self): openPath = self.path[: self.path.index("*")] if openPath.endswith("\\"): - openPath = self.path[: -1] + openPath = openPath[: -1] # We need to start by opening the directory. request = DeviceCreateRequestPDU( diff --git a/pyrdp/mitm/FileCrawlerMITM.py b/pyrdp/mitm/FileCrawlerMITM.py index 656287a60..6877a8069 100644 --- a/pyrdp/mitm/FileCrawlerMITM.py +++ b/pyrdp/mitm/FileCrawlerMITM.py @@ -3,18 +3,19 @@ # Copyright (C) 2019 GoSecure Inc. # Licensed under the GPLv3 or later. # +import fnmatch from collections import defaultdict from logging import LoggerAdapter from pathlib import Path -from typing import BinaryIO, Dict, List, Optional, Set +from typing import Dict, List, Optional, Set from pyrdp.enum.virtual_channel.device_redirection import DeviceType -from pyrdp.mitm.config import MITMConfig from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM, DeviceRedirectionMITMObserver +from pyrdp.mitm.FileMapping import FileMapping +from pyrdp.mitm.config import MITMConfig from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import DeviceAnnounce -import fnmatch class VirtualFile: """ @@ -61,12 +62,11 @@ def __init__(self, mainLogger: LoggerAdapter, fileLogger: LoggerAdapter, config: self.deviceRedirection: Optional[DeviceRedirectionMITM] = None # Pending crawler requests - self.fileDownloadRequests: Dict[int, Path] = {} self.directoryListingRequests: Dict[int, Path] = {} self.directoryListingLists = defaultdict(list) # Download management - self.downloadFiles: Dict[str, BinaryIO] = {} + self.fileMappings: Dict[str, FileMapping] = {} self.downloadDirectories: Set[int] = set() # Crawler detection patterns @@ -101,9 +101,6 @@ def preparePatterns(self): Should only be called once. """ - matchPath = None - ignorePath = None - # Get the default file in pyrdp/mitm/crawler_config if self.config.crawlerMatchFileName: matchPath = Path(self.config.crawlerMatchFileName).absolute() @@ -126,7 +123,7 @@ def parsePatterns(self, path: str) -> List[str]: try: with open(path, "r") as f: for line in f: - if line and line[0] in ["#", " ", "\n"]: + if not line or line[0] in ["#", " ", "\n"]: continue patternList.append(line.lower().rstrip()) @@ -136,10 +133,22 @@ def parsePatterns(self, path: str) -> List[str]: return patternList + def onDeviceAnnounce(self, device: DeviceAnnounce): + if device.deviceType == DeviceType.RDPDR_DTYP_FILESYSTEM: + + drive = VirtualFile(device.deviceID, device.preferredDOSName, "/", True) + + self.devices[drive.deviceID] = drive + self.unvisitedDrive.append(drive) + + # If the crawler hasn't started, start one instance + if len(self.devices) == 1: + self.dispatchDownload() + def dispatchDownload(self): """ Processes each queue in order of priority. - File download have priority over directory download. + File downloads have priority over directory downloads. Crawl each folder before visiting another drive. """ @@ -161,15 +170,56 @@ def dispatchDownload(self): # List an unvisited drive elif len(self.unvisitedDrive) != 0: drive = self.unvisitedDrive.pop() - - # TODO : Maybe dump whole drive if there isn't a lot of files? - # Maybe if theres no directory at the root directory -> dump all? self.log.info("Begin crawling disk %(disk)s", {"disk" : drive.name}) self.fileLogger.info("Begin crawling disk %(disk)s", {"disk" : drive.name}) self.listDirectory(drive.deviceID, drive.path) else: self.log.info("Done crawling.") + def listDirectory(self, deviceID: int, path: str, download: bool = False): + """ + List the directory + :param deviceID: Drive we are actually listing. + :param path: Path of the directory we are listing. + :param download: Wether or not we need to download this directory. + """ + listingPath = str(Path(path).absolute()).replace("/", "\\") + + if not listingPath.endswith("*"): + if not listingPath.endswith("\\"): + listingPath += "\\" + + listingPath += "*" + + requestID = self.deviceRedirection.sendForgedDirectoryListing(deviceID, listingPath) + + # If the directory is flagged for download, keep trace of the incoming request to trigger download. + if download: + self.downloadDirectories.add(requestID) + + self.directoryListingRequests[requestID] = Path(path).absolute() + + def onDirectoryListingResult(self, deviceID: int, requestID: int, fileName: str, isDirectory: bool): + if requestID not in self.directoryListingRequests: + return + + path = self.directoryListingRequests[requestID] + filePath = path / fileName + + file = VirtualFile(deviceID, fileName, str(filePath), isDirectory) + directoryList = self.directoryListingLists[requestID] + directoryList.append(file) + + def onDirectoryListingComplete(self, deviceID: int, requestID: int): + self.directoryListingRequests.pop(requestID, {}) + + # If directory was flagged for download + if requestID in self.downloadDirectories: + self.downloadDirectories.remove(requestID) + self.addListingToDownloadQueue(requestID) + else: + self.crawlListing(requestID) + def addListingToDownloadQueue(self, requestID: int): directoryList = self.directoryListingLists.pop(requestID, {}) @@ -181,6 +231,7 @@ def addListingToDownloadQueue(self, requestID: int): self.matchedDirectoryQueue.append(item) else: self.matchedFileQueue.append(item) + self.dispatchDownload() def crawlListing(self, requestID: int): @@ -211,74 +262,34 @@ def crawlListing(self, requestID: int): if matched: self.matchedFileQueue.append(item) - self.fileLogger.info("%(file)s - %(isDirectory)s - %(isDownloaded)s", {"file" : item.path, "isDirectory": item.isDirectory, "isDownloaded": matched}) + self.fileLogger.info("%(file)s - %(isDirectory)s - %(isMatched)s", { + "file" : item.path, + "isDirectory": item.isDirectory, + "isMatched": matched + }) + self.dispatchDownload() def downloadFile(self, file: VirtualFile): remotePath = file.path - basePath = f"{self.config.fileDir}/{self.log.sessionID}" - localPath = f"{basePath}{remotePath}" - - self.log.info("Saving %(remotePath)s to %(localPath)s", {"remotePath": remotePath, "localPath": localPath}) - - try: - # Create parent directory, don't raise error if it already exists - Path(localPath).parent.mkdir(parents=True, exist_ok=True) - targetFile = open(localPath, "wb") - except Exception as e: - self.log.exception(e) - self.log.error("Cannot save file: %(localPath)s", {"localPath": localPath}) - return - - self.downloadFiles[remotePath] = targetFile + mapping = FileMapping.generate( + remotePath, + self.config.fileDir, + self.deviceRedirection.createDeviceRoot(file.deviceID), + self.log + ) + + self.fileMappings[remotePath] = mapping self.deviceRedirection.sendForgedFileRead(file.deviceID, remotePath) - def listDirectory(self, deviceID: int, path: str, download: bool = False): - """ - List the directory - :param deviceID: Drive we are actually listing. - :param path: Path of the directory we are listing. - :param download: Wether or not we need to download this directory. - """ - listingPath = str(Path(path).absolute()).replace("/", "\\") - - if not listingPath.endswith("*"): - if not listingPath.endswith("\\"): - listingPath += "\\" - - listingPath += "*" - - requestID = self.deviceRedirection.sendForgedDirectoryListing(deviceID, listingPath) - - # If the directory is flagged for download, keep trace of the incoming request to trigger download. - if download: - self.downloadDirectories.add(requestID) - - self.directoryListingRequests[requestID] = Path(path).absolute() - - def onDeviceAnnounce(self, device: DeviceAnnounce): - if device.deviceType == DeviceType.RDPDR_DTYP_FILESYSTEM: - - drive = VirtualFile(device.deviceID, device.preferredDOSName, "/", True) - - self.devices[drive.deviceID] = drive - self.unvisitedDrive.append(drive) - - # If the crawler hasn't started, start one instance - if len(self.devices) == 1: - self.dispatchDownload() - def onFileDownloadResult(self, deviceID: int, requestID: int, path: str, offset: int, data: bytes): - remotePath = path.replace("\\", "/") - - targetFile = self.downloadFiles[remotePath] - targetFile.write(data) + mapping = self.fileMappings[path] + mapping.seek(offset) + mapping.write(data) def onFileDownloadComplete(self, deviceID: int, requestID: int, path: str, errorCode: int): - remotePath = path.replace("\\", "/") - - file = self.downloadFiles.pop(remotePath) - file.close() + mapping = self.fileMappings.pop(path) + mapping.finalize() if errorCode != 0: # TODO : Handle common error codes like : @@ -286,29 +297,8 @@ def onFileDownloadComplete(self, deviceID: int, requestID: int, path: str, error # Doc : https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-erref/18d8fbe8-a967-4f1c-ae50-99ca8e491d2d self.log.error("Error happened when downloading %(remotePath)s. The file may not have been saved completely. Error code: %(errorCode)s", { - "remotePath": remotePath, + "remotePath": path, "errorCode": "0x%08lx" % errorCode, }) self.dispatchDownload() - - def onDirectoryListingResult(self, deviceID: int, requestID: int, fileName: str, isDirectory: bool): - if requestID not in self.directoryListingRequests: - return - - path = self.directoryListingRequests[requestID] - filePath = path / fileName - - file = VirtualFile(deviceID, fileName, str(filePath), isDirectory) - directoryList = self.directoryListingLists[requestID] - directoryList.append(file) - - def onDirectoryListingComplete(self, deviceID: int, requestID: int): - self.directoryListingRequests.pop(requestID, {}) - - # If directory was flagged for download - if requestID in self.downloadDirectories: - self.downloadDirectories.remove(requestID) - self.addListingToDownloadQueue(requestID) - else: - self.crawlListing(requestID) diff --git a/pyrdp/mitm/FileMapping.py b/pyrdp/mitm/FileMapping.py index a265d39e3..8d0c1baaf 100644 --- a/pyrdp/mitm/FileMapping.py +++ b/pyrdp/mitm/FileMapping.py @@ -4,12 +4,11 @@ # Licensed under the GPLv3 or later. # -import datetime -import json +import hashlib +import tempfile +from logging import LoggerAdapter from pathlib import Path -from typing import Dict - -import names +from typing import io class FileMapping: @@ -18,69 +17,82 @@ class FileMapping: transferred over RDP. """ - def __init__(self, remotePath: Path, localPath: Path, creationTime: datetime.datetime, fileHash: str): + def __init__(self, file: io.BinaryIO, dataPath: Path, filesystemPath: Path, filesystemDir: Path, log: LoggerAdapter): """ - :param remotePath: the path of the file on the original machine - :param localPath: the path of the file on the intercepting machine - :param creationTime: the creation time of the local file - :param fileHash: the file hash in hex format (empty string if the file is not complete) + :param file: the file handle for dataPath + :param dataPath: path where the file is actually saved + :param filesystemPath: the path to the replicated filesystem, which will be symlinked to dataPath + :param log: logger """ - self.remotePath = remotePath - self.localPath = localPath - self.creationTime = creationTime - self.hash: str = fileHash + self.file = file + self.filesystemPath = filesystemPath + self.dataPath = dataPath + self.filesystemDir = filesystemDir + self.log = log + self.written = False - @staticmethod - def generate(remotePath: Path, outDir: Path): - localName = f"{names.get_first_name()}{names.get_last_name()}" - creationTime = datetime.datetime.now() + def seek(self, offset: int): + self.file.seek(offset) - index = 2 - suffix = "" + def write(self, data: bytes): + self.file.write(data) + self.written = True - while True: - if not (outDir / f"{localName}{suffix}").exists(): - break - else: - suffix = f"_{index}" - index += 1 + def getHash(self): + with open(self.dataPath, "rb") as f: + sha1 = hashlib.sha1() - localName += suffix + while True: + buffer = f.read(65536) - return FileMapping(remotePath, outDir / localName, creationTime, "") + if len(buffer) == 0: + break + sha1.update(buffer) -class FileMappingEncoder(json.JSONEncoder): - """ - JSON encoder for FileMapping objects. - """ + return sha1.hexdigest() - def default(self, o): - if isinstance(o, datetime.datetime): - return o.isoformat() - elif not isinstance(o, FileMapping): - return super().default(o) + def finalize(self): + self.log.debug("Closing file %(path)s", {"path": self.dataPath}) + self.file.close() - return { - "remotePath": str(o.remotePath), - "localPath": str(o.localPath), - "creationTime": o.creationTime, - "sha1": o.hash - } + fileHash = self.getHash() + # Go up one directory because files are saved to outDir / tmp while we're downloading them + hashPath = (self.dataPath.parents[1] / fileHash) -class FileMappingDecoder(json.JSONDecoder): - """ - JSON decoder for FileMapping objects. - """ + # Don't keep the file if we haven't written anything to it or it's a duplicate, otherwise rename and move to files dir + if not self.written or hashPath.exists(): + self.dataPath.unlink() + else: + self.dataPath = self.dataPath.rename(hashPath) + + # Whether it's a duplicate or a new file, we need to create a link to it in the filesystem clone + if self.written: + self.filesystemPath.parents[0].mkdir(exist_ok=True) + + if self.filesystemPath.exists(): + self.filesystemPath.unlink() + + self.filesystemPath.symlink_to(hashPath) + + self.log.info("SHA1 '%(path)s' = '%(hash)s'", { + "path": self.filesystemPath.relative_to(self.filesystemDir), "hash": fileHash + }) + + @staticmethod + def generate(remotePath: str, outDir: Path, filesystemDir: Path, log: LoggerAdapter): + remotePath = Path(remotePath.replace("\\", "/")) + filesystemPath = filesystemDir / remotePath.relative_to("/") + + tmpOutDir = outDir / "tmp" + tmpOutDir.mkdir(exist_ok=True) - def __init__(self): - super().__init__(object_hook=self.decodeFileMapping) + handle, tmpPath = tempfile.mkstemp("", "", tmpOutDir) + file = open(handle, "wb") - def decodeFileMapping(self, dct: Dict): - for key in ["remotePath", "localPath", "creationTime"]: - if key not in dct: - return dct + log.info("Saving file '%(remotePath)s' to '%(localPath)s'", { + "localPath": tmpPath, "remotePath": remotePath + }) - creationTime = datetime.datetime.strptime(dct["creationTime"], "%Y-%m-%dT%H:%M:%S.%f") - return FileMapping(Path(dct["remotePath"]), Path(dct["localPath"]), creationTime, dct["sha1"]) \ No newline at end of file + return FileMapping(file, Path(tmpPath), filesystemPath, filesystemDir, log) diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 5df6a13e8..9babc199f 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -85,7 +85,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf self.statCounter = StatCounter() """Class to keep track of connection-related statistics such as # of mouse events, # of output events, etc.""" - self.state = state if state is not None else RDPMITMState(self.config) + self.state = state if state is not None else RDPMITMState(self.config, self.log.sessionID) """The MITM state""" self.client = RDPLayerSet() @@ -152,7 +152,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf replayFileName = "rdp_replay_{}_{}_{}.pyrdp"\ .format(date.strftime('%Y%m%d_%H-%M-%S'), date.microsecond // 1000, - self.log.sessionID) + self.state.sessionID) self.recorder.setRecordFilename(replayFileName) self.recorder.addTransport(FileLayer(self.config.replayDir / replayFileName)) @@ -339,10 +339,10 @@ def buildClipboardChannel(self, client: MCSServerChannel, server: MCSClientChann if self.config.disableActiveClipboardStealing: mitm = PassiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), - self.recorder, self.statCounter) + self.recorder, self.statCounter, self.state) else: mitm = ActiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), - self.recorder, self.statCounter) + self.recorder, self.statCounter, self.state) self.channelMITMs[client.channelID] = mitm def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel): diff --git a/pyrdp/mitm/config.py b/pyrdp/mitm/config.py index a88d28fd2..f12a5cf4b 100644 --- a/pyrdp/mitm/config.py +++ b/pyrdp/mitm/config.py @@ -96,6 +96,13 @@ def fileDir(self) -> Path: """ return self.outDir / "files" + @property + def filesystemDir(self) -> Path: + """ + Get the directory for filesystem clones. + """ + return self.outDir / "filesystems" + @property def certDir(self) -> Path: """ diff --git a/pyrdp/mitm/state.py b/pyrdp/mitm/state.py index e93bf9aa0..35998994a 100644 --- a/pyrdp/mitm/state.py +++ b/pyrdp/mitm/state.py @@ -21,7 +21,7 @@ class RDPMITMState: State object for the RDP MITM. This is for data that needs to be shared across components. """ - def __init__(self, config: MITMConfig): + def __init__(self, config: MITMConfig, sessionID: str): self.requestedProtocols: Optional[NegotiationProtocols] = None """The original request protocols""" @@ -73,6 +73,9 @@ def __init__(self, config: MITMConfig): self.ctrlPressed = False """The current keybaord ctrl state""" + self.sessionID = sessionID + """The current session ID""" + self.securitySettings.addObserver(self.crypters[ParserMode.CLIENT]) self.securitySettings.addObserver(self.crypters[ParserMode.SERVER]) diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py new file mode 100644 index 000000000..cd347de47 --- /dev/null +++ b/test/test_DeviceRedirectionMITM.py @@ -0,0 +1,358 @@ +import unittest +from pathlib import Path +from unittest.mock import Mock, MagicMock, patch + +from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity, DeviceRedirectionPacketID, MajorFunction, \ + MinorFunction +from pyrdp.logging.StatCounter import StatCounter, STAT +from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM +from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU + + +def MockIOError(): + ioError = Mock(deviceID = 0, completionID = 0, ioStatus = IOOperationSeverity.STATUS_SEVERITY_ERROR << 30) + return ioError + + +class DeviceRedirectionMITMTest(unittest.TestCase): + def setUp(self): + self.client = Mock() + self.server = Mock() + self.log = Mock() + self.statCounter = Mock() + self.state = Mock() + self.state.config = MagicMock() + self.state.config.outDir = Path("/tmp") + self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state) + + @patch("pyrdp.mitm.FileMapping.FileMapping.generate") + def sendCreateResponse(self, request, response, generate): + self.mitm.handleCreateResponse(request, response) + return generate + + def test_stats(self): + self.mitm.handlePDU = Mock() + self.mitm.statCounter = StatCounter() + + self.mitm.onClientPDUReceived(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION], 1) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_CLIENT], 1) + + self.mitm.onServerPDUReceived(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION], 2) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_SERVER], 1) + + self.mitm.handleIORequest(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOREQUEST], 1) + + self.mitm.handleIOResponse(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IORESPONSE], 1) + + error = MockIOError() + self.mitm.handleIORequest(error) + self.mitm.handleIOResponse(error) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOERROR], 1) + + self.mitm.handleCloseResponse(Mock(), Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FILE_CLOSE], 1) + + self.mitm.sendForgedFileRead(Mock(), Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FORGED_FILE_READ], 1) + + self.mitm.sendForgedDirectoryListing(Mock(), MagicMock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING], 1) + + def test_ioError_showsWarning(self): + self.log.warning = Mock() + error = MockIOError() + + self.mitm.handleIORequest(error) + self.mitm.handleIOResponse(error) + self.log.warning.assert_called_once() + + def test_deviceListAnnounce_logsDevices(self): + pdu = Mock() + pdu.deviceList = [Mock(), Mock(), Mock()] + + self.mitm.observer = Mock() + self.mitm.handleDeviceListAnnounceRequest(pdu) + + self.assertEqual(self.log.info.call_count, len(pdu.deviceList)) + self.assertEqual(self.mitm.observer.onDeviceAnnounce.call_count, len(pdu.deviceList)) + + def test_handleClientLogin_logsCredentials(self): + creds = "PASSWORD" + self.log.info = Mock() + + self.state.credentialsCandidate = creds + self.state.inputBuffer = "" + self.mitm.handleClientLogin() + self.log.info.assert_called_once() + self.assertTrue(creds in self.log.info.call_args[0][1].values()) + + self.log.info.reset_mock() + self.state.credentialsCandidate = "" + self.state.inputBuffer = creds + self.mitm.handleClientLogin() + self.log.info.assert_called_once() + self.assertTrue(creds in self.log.info.call_args[0][1].values()) + + self.mitm.handleClientLogin = Mock() + pdu = Mock(packetID = DeviceRedirectionPacketID.PAKID_CORE_USER_LOGGEDON) + pdu.__class__ = DeviceRedirectionPDU + + self.mitm.handlePDU(pdu, self.client) + self.mitm.handleClientLogin.assert_called_once() + + def test_handleIOResponse_uniqueResponse(self): + handler = Mock() + self.mitm.responseHandlers[1234] = handler + + pdu = Mock(deviceID = 0, completionID = 0, majorFunction = 1234, ioStatus = 0) + self.mitm.handleIORequest(pdu) + self.mitm.handleIOResponse(pdu) + handler.assert_called_once() + + # Second response should not go through + self.mitm.handleIOResponse(pdu) + handler.assert_called_once() + + def test_handleIOResponse_matchingOnly(self): + handler = Mock() + self.mitm.responseHandlers[1234] = handler + + request = Mock(deviceID = 0, completionID = 0) + matching_response = Mock(deviceID = 0, completionID = 0, majorFunction = 1234, ioStatus = 0) + bad_completionID = Mock(deviceID = 0, completionID = 1, majorFunction = 1234, ioStatus = 0) + bad_deviceID = Mock(deviceID = 1, completionID = 0, majorFunction = 1234, ioStatus = 0) + + self.mitm.handleIORequest(request) + self.mitm.handleIOResponse(matching_response) + handler.assert_called_once() + + self.mitm.handleIORequest(request) + + self.mitm.handleIOResponse(bad_completionID) + handler.assert_called_once() + self.log.error.assert_called_once() + self.log.error.reset_mock() + + self.mitm.handleIOResponse(bad_deviceID) + handler.assert_called_once() + self.log.error.assert_called_once() + self.log.error.reset_mock() + + def test_handlePDU_hidesForgedResponses(self): + majorFunction = MajorFunction.IRP_MJ_CREATE + handler = Mock() + completionID = self.mitm.sendForgedFileRead(0, "forged") + request = self.mitm.forgedRequests[(0, completionID)] + request.handlers[majorFunction] = handler + + self.assertEqual(len(self.mitm.forgedRequests), 1) + response = Mock(deviceID = 0, completionID = completionID, majorFunction = majorFunction, ioStatus = 0) + response.__class__ = DeviceIOResponsePDU + self.mitm.handlePDU(response, self.mitm.server) + handler.assert_called_once() + self.mitm.server.sendPDU.assert_not_called() + + def test_handleCreateResponse_createsMapping(self): + createRequest = Mock( + deviceID = 0, + completionID = 0, + desiredAccess = (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions = CreateOption.FILE_NON_DIRECTORY_FILE, + path = "file", + ) + createResponse = Mock(deviceID = 0, completionID = 0, fileID = 0) + + generate = self.sendCreateResponse(createRequest, createResponse) + self.assertEqual(len(self.mitm.mappings), 1) + generate.assert_called_once() + + def test_handleReadResponse_writesData(self): + request = Mock( + deviceID = 0, + completionID = 0, + fileID = 0, + desiredAccess = (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions = CreateOption.FILE_NON_DIRECTORY_FILE, + path = "file", + ) + response = Mock(deviceID = 0, completionID = 0, fileID = 0, payload = "test payload") + self.mitm.saveMapping = Mock() + + self.sendCreateResponse(request, response) + mapping = list(self.mitm.mappings.values())[0] + mapping.write = Mock() + + self.mitm.handleReadResponse(request, response) + mapping.write.assert_called_once() + + # Make sure it checks the file ID + request.fileID, response.fileID = 1, 1 + self.mitm.handleReadResponse(request, response) + mapping.write.assert_called_once() + + def test_handleCloseResponse_finalizesMapping(self): + request = Mock( + deviceID=0, + completionID=0, + fileID=0, + desiredAccess=(FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions=CreateOption.FILE_NON_DIRECTORY_FILE, + path="file", + ) + response = Mock(deviceID=0, completionID=0, fileID=0, payload="test payload") + self.mitm.saveMapping = Mock() + + self.sendCreateResponse(request, response) + mapping = list(self.mitm.mappings.values())[0] + mapping.finalize = Mock() + + self.mitm.handleCloseResponse(request, response) + + mapping.finalize.assert_called_once() + + def test_findNextRequestID_incrementsRequestID(self): + baseID = self.mitm.findNextRequestID() + self.mitm.sendForgedFileRead(0, Mock()) + self.assertEqual(self.mitm.findNextRequestID(), baseID + 1) + self.mitm.sendForgedFileRead(1, Mock()) + self.assertEqual(self.mitm.findNextRequestID(), baseID + 2) + + def test_sendForgedFileRead_failsWhenDisabled(self): + self.mitm.config.extractFiles = False + self.assertFalse(self.mitm.sendForgedFileRead(1, "/test")) + + def test_sendForgedDirectoryListing_failsWhenDisabled(self): + self.mitm.config.extractFiles = False + self.assertFalse(self.mitm.sendForgedDirectoryListing(1, "/")) + + +class ForgedRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedRequest(0, 0, Mock()) + + def test_sendIORequest_sendsToClient(self): + self.request.sendIORequest(Mock()) + self.request.mitm.client.sendPDU.assert_called_once() + + def test_onCloseResponse_completesRequest(self): + self.request.onCloseResponse(Mock()) + self.assertTrue(self.request.isComplete) + + def test_onCreateResponse_checksStatus(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.assertIsNone(self.request.fileID) + + +class ForgedFileReadRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedFileReadRequest(0, 0, Mock(), "file") + + def test_onCreateResponse_sendsReadRequest(self): + self.request.sendReadRequest = Mock() + self.request.onCreateResponse(Mock(ioStatus = 0)) + self.request.sendReadRequest.assert_called_once() + + def test_onCreateResponse_completesRequest(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + self.assertTrue(self.request.isComplete) + + def test_handleFileComplete_sendsCloseRequest(self): + self.request.sendCloseRequest = Mock() + self.request.fileID = Mock() + self.request.handleFileComplete(1) + self.request.sendCloseRequest.assert_called_once() + + def test_onReadResponse_closesOnError(self): + self.request.fileID = Mock() + self.request.sendCloseRequest = Mock() + self.request.mitm.observer.onFileDownloadComplete = Mock() + self.request.onReadResponse(Mock(ioStatus = 1)) + self.request.sendCloseRequest.assert_called_once() + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + + def test_onReadResponse_updatesProgress(self): + payload = b"testing" + self.request.sendReadRequest = Mock() + self.request.mitm.observer.onFileDownloadResult = Mock() + self.request.onReadResponse(Mock(ioStatus = 0, payload = payload)) + + self.assertEqual(self.request.offset, len(payload)) + self.request.mitm.observer.onFileDownloadResult.assert_called_once() + self.request.sendReadRequest.assert_called_once() + + def test_onReadResponse_closesWhenDone(self): + self.request.fileID = Mock() + self.request.sendCloseRequest = Mock() + self.request.mitm.observer.onFileDownloadComplete = Mock() + self.request.onReadResponse(Mock(ioStatus = 0, payload = b"")) + self.request.sendCloseRequest.assert_called_once() + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + + +class ForgedDirectoryListingRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedDirectoryListingRequest(0, 0, Mock(), "directory") + + def test_send_removesTrailingSlash(self): + self.request.sendIORequest = Mock() + self.request.path = "directory\\" + + self.request.send() + ioRequest = self.request.sendIORequest.call_args[0][0] + self.assertEqual(ioRequest.path, "directory") + + def test_send_handlesWildcard(self): + self.request.sendIORequest = Mock() + self.request.path = "directory\\*" + + self.request.send() + ioRequest = self.request.sendIORequest.call_args[0][0] + self.assertEqual(ioRequest.path, "directory") + + def test_send_handlesNormalPath(self): + self.request.sendIORequest = Mock() + self.request.send() + + ioRequest = self.request.sendIORequest.call_args[0][0] + self.request.sendIORequest.assert_called_once() + self.assertEqual(ioRequest.path, "directory") + + def test_onCreateResponse_completesOnError(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.assertTrue(self.request.isComplete) + + def test_onCreateResponse_sendsDirectoryRequest(self): + self.request.sendIORequest = Mock() + self.request.onCreateResponse(Mock(ioStatus = 0)) + self.request.sendIORequest.assert_called_once() + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_DIRECTORY_CONTROL) + self.assertEqual(self.request.sendIORequest.call_args[0][0].minorFunction, MinorFunction.IRP_MN_QUERY_DIRECTORY) + + def test_onDirectoryControlResponse_completesOnError(self): + self.request.sendIORequest = Mock() + self.request.onDirectoryControlResponse(Mock(ioStatus = 1, minorFunction = MinorFunction.IRP_MN_QUERY_DIRECTORY)) + self.request.sendIORequest.assert_called_once() + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_CLOSE) + self.request.mitm.observer.onDirectoryListingComplete.assert_called_once() + + def test_onDirectoryControlResponse_handlesSuccessfulResponse(self): + self.request.sendIORequest = Mock() + response = MagicMock( + ioStatus = 0, + minorFunction = MinorFunction.IRP_MN_QUERY_DIRECTORY, + fileInformation = [MagicMock()] + ) + + self.request.onDirectoryControlResponse(response) + + # Sends result to observer + self.request.mitm.observer.onDirectoryListingResult.assert_called_once() + + # Sends follow-up directory listing request + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_DIRECTORY_CONTROL) + self.assertEqual(self.request.sendIORequest.call_args[0][0].minorFunction, MinorFunction.IRP_MN_QUERY_DIRECTORY) diff --git a/test/test_FileMapping.py b/test/test_FileMapping.py new file mode 100644 index 000000000..db385f846 --- /dev/null +++ b/test/test_FileMapping.py @@ -0,0 +1,85 @@ +import unittest +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock, mock_open + +from pyrdp.mitm.FileMapping import FileMapping + + +class FileMappingTest(unittest.TestCase): + def setUp(self): + self.log = Mock() + self.outDir = Path("test/") + self.hash = "testHash" + + @patch("builtins.open", new_callable=mock_open) + @patch("tempfile.mkstemp") + @patch("pathlib.Path.mkdir") + def createMapping(self, mkdir: MagicMock, mkstemp: MagicMock, mock_open_object): + mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test")) + mapping = FileMapping.generate("/test", self.outDir, Path("filesystems"), self.log) + mapping.getHash = Mock(return_value = self.hash) + return mapping, mkdir, mkstemp, mock_open_object + + def test_generate_createsTempFile(self): + mapping, mkdir, mkstemp, mock_open_object = self.createMapping() + mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test")) + + mkdir.assert_called_once_with(exist_ok = True) + mkstemp.assert_called_once() + mock_open_object.assert_called_once() + + tmpDir = mkstemp.call_args[0][-1] + self.assertEqual(tmpDir, self.outDir / "tmp") + + def test_write_setsWritten(self): + mapping, *_ = self.createMapping() + self.assertFalse(mapping.written) + mapping.write(b"data") + self.assertTrue(mapping.written) + + def test_finalize_removesUnwrittenFiles(self): + mapping, *_ = self.createMapping() + + with patch("pathlib.Path.unlink", autospec=True) as mock_unlink: + mapping.finalize() + self.assertTrue(any(args[0][0] == mapping.dataPath for args in mock_unlink.call_args_list)) + + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=True)) + @patch("pathlib.Path.symlink_to") + @patch("pathlib.Path.mkdir") + def test_finalize_removesDuplicates(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.unlink", autospec=True) as mock_unlink: + mapping.finalize() + self.assertTrue(any(args[0][0] == mapping.dataPath for args in mock_unlink.call_args_list)) + + @patch("pathlib.Path.unlink") + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=False)) + @patch("pathlib.Path.symlink_to") + @patch("pathlib.Path.mkdir") + def test_finalize_movesFileToOutDir(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.rename") as mock_rename: + mapping.finalize() + mock_rename.assert_called_once() + self.assertEqual(mock_rename.call_args[0][0].parents[0], self.outDir) + + @patch("pathlib.Path.rename") + @patch("pathlib.Path.unlink") + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=False)) + def test_finalize_createsSymlink(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.symlink_to") as mock_symlink_to, patch("pathlib.Path.mkdir", autospec=True) as mock_mkdir: + mapping.finalize() + + mock_mkdir.assert_called_once() + mock_symlink_to.assert_called_once() + + self.assertEqual(mock_mkdir.call_args[0][0], mapping.filesystemPath.parents[0]) + self.assertEqual(mock_symlink_to.call_args[0][0], self.outDir / self.hash) diff --git a/test/test_prerecorded.py b/test/test_prerecorded.py index a0e745aed..0a30e9b46 100644 --- a/test/test_prerecorded.py +++ b/test/test_prerecorded.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): config.outDir = output_directory # replay_transport = FileLayer(output_path) - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) super().__init__(log, log, config, state, CustomMITMRecorder([], state)) self.client.tcp.sendBytes = sendBytesStub