diff --git a/pypeman/channels.py b/pypeman/channels.py index 8a09313..1ee7bb2 100644 --- a/pypeman/channels.py +++ b/pypeman/channels.py @@ -930,3 +930,4 @@ async def watch_for_file(self): wrap.add_lazy('pypeman.contrib.http', 'HttpChannel', ['aiohttp']) wrap.add_lazy('pypeman.contrib.time', 'CronChannel', ['aiocron']) wrap.add_lazy('pypeman.contrib.ftp', 'FTPWatcherChannel', []) +wrap.add_lazy('pypeman.contrib.sftp', 'SFTPWatcherChannel', []) diff --git a/pypeman/contrib/sftp.py b/pypeman/contrib/sftp.py new file mode 100644 index 0000000..31606af --- /dev/null +++ b/pypeman/contrib/sftp.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python + +import asyncio +import logging +import os +import re + +from pathlib import Path + +import asyncssh + +from pypeman import channels, nodes, message, persistence + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) +logging.getLogger("asyncssh").setLevel(logging.WARNING) # asyncssh's logs are really verbose + +SFTP_TIMEOUT = 10 # timeout in seconds + + +class SFTPConnection(): + """ + FTP connection manager. + """ + def __init__(self, host, port=22, credentials=None, hostkey=None): + """ + :param host: SFTP host. + :param port: SFTP port + :param credentials: A tuple with (login, password) + :param hostkey: Key to login with (optional if auth with login/pwd only) + :return: + """ + self.conn_args = { + "host": host, + "port": port, + "connect_timeout": SFTP_TIMEOUT, + "login_timeout": SFTP_TIMEOUT, + "known_hosts": None, # TODO: change that + } + + if not (credentials or hostkey): + raise TypeError("must have credentials or hostkey param") + if credentials: + self.conn_args["username"] = credentials[0] + self.conn_args["password"] = credentials[1] + if hostkey: + self.conn_args["client_keys"] = [hostkey] + + async def __aenter__(self): + self.client = await asyncssh.connect( + **self.conn_args + ) + self.sftp = await self.client.start_sftp_client() + return self.sftp + + async def __aexit__(self, exc_type, exc_value, tb): + self.sftp.exit() + self.client.close() + + +class SFTPHelper(): + """ + SFTP helper to abstract sftp access. + """ + + def __init__(self, host, port=22, credentials=None, hostkey=None): + self.host = host + self.port = port + self.hostkey = hostkey + self.credentials = credentials + + async def list_dir(self, path): + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + listdir = await sftp_conn.readdir(path) + # TODO: test sort + sortedlistdir = sorted(listdir, key=lambda sshfile: sshfile.attrs.mtime) + return sortedlistdir + + async def download_file(self, filepath, encoding="utf-8"): + """ + Download a file from sftp asynchronously. + :param filepath: file path to download. + :return: content of the file. + """ + # Get file content from sftp + logger.debug("SFTP download file from fpath %s", filepath) + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + async with sftp_conn.open(filepath, asyncssh.FXF_READ, encoding=encoding) as fin: + content = await fin.read() + return content + + async def upload_file(self, filepath, content, encoding="utf-8"): + """ + Upload a file to sftp. + :param filepath: Path of file to create. + :param content: Content to upload. + """ + # write file in sftp + logger.debug("SFTP upload file to fpath %s", filepath) + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + async with sftp_conn.open(filepath, asyncssh.FXF_WRITE, encoding=encoding) as fout: + content = await fout.write(content) + return content + + async def rename(self, fromfilepath, tofilepath): + """ + Rename a file from path to another path in ftp. + :param fromfilepath: original file to rename. + :param tofilepath: destination file. + """ + logger.debug("SFTP rename file %s to %s", fromfilepath, tofilepath) + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + await sftp_conn.rename(fromfilepath, tofilepath) + + async def delete(self, filepath): + """ + Delete an SFTP file. + :param filepath: File to delete. + """ + logger.debug("SFTP Delete file %s", filepath) + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + await sftp_conn.remove(filepath) + + async def file_exists(self, filepath): + """ + Return True if a file exists in the SFTP, False otherwise + :param filepath: File to check existence. + """ + async with SFTPConnection(host=self.host, port=self.port, + credentials=self.credentials, + hostkey=self.hostkey) as sftp_conn: + if await sftp_conn.exists(filepath): + return True + else: + return False + + +class SFTPWatcherChannel(channels.BaseChannel): + """ + Channel that watch sftp for file creation. + """ + + PERSISTENCE_TABLENAME = "sftpwatcher" + + def __init__(self, *args, host, port=22, credentials=None, hostkey=None, basedir="", + regex='.*', interval=6, delete_after=False, encoding="utf-8", + real_extensions=None, **kwargs): + super().__init__(*args, **kwargs) + + self.basedir = basedir + self.interval = interval + self.delete_after = delete_after + self.re = re.compile(regex) + self.encoding = encoding # If set to None, read as bytes + + self.LAST_READ_MTIME_FIELDNAME = f"{self.name}_last_read_mtime" + self.real_extensions = real_extensions # list of extensions for exemple: [".csv", ".CSV"] + + self.sftphelper = SFTPHelper( + host=host, port=port, credentials=credentials, hostkey=hostkey) + self.backend = None + self.last_read_mtime = 0 + + async def get_last_read_mtime(self): + """ + Get last read mtime in the persistence backend + !caution! the channel must be started + """ + last_read_mtime = await self.backend.get( + self.PERSISTENCE_TABLENAME, self.LAST_READ_MTIME_FIELDNAME, default=0) + return last_read_mtime + + async def set_last_read_mtime(self, mtime_value): + """ + Set last read mtime in the persistence backend + !caution! the channel must be started + """ + await self.backend.store( + self.PERSISTENCE_TABLENAME, self.LAST_READ_MTIME_FIELDNAME, mtime_value) + self.last_read_mtime = mtime_value + return mtime_value + + async def start(self): + if not self.backend: + self.backend = await persistence.get_backend(loop=self.loop) + self.last_read_mtime = await self.get_last_read_mtime() + logger.debug("last_read_mtime at start is %s", str(self.last_read_mtime)) + await super().start() + asyncio.create_task(self.watch_for_file()) + + async def download_file(self, filename): + """ + Download a file from sftp asynchronously. + + :param filepath: file path to download. + + :return: Content of the downloaded file. + """ + if not self.is_stopped(): + logger.debug("file %s found in %s", filename, self.basedir) + return await self.sftphelper.download_file( + self.basedir + '/' + filename, encoding=self.encoding) + + async def get_file_and_process(self, filename): + """ + Download a file from sftp and launch channel processing on msg with result as payload. + Also add a `filepath` header with sftp relative path of downloaded file. + + :param filename: file to download relative to `basedir`. + + :return: processed result + """ + logger.debug("start handling of %s", filename) + if self.real_extensions: + fpath = Path(self.basedir) / filename + for extension in self.real_extensions: + real_fpath = fpath.with_suffix(extension) + if await self.sftphelper.file_exists(str(real_fpath)): + filename = real_fpath.name + break + else: + # If no related files + logger.error( + "No %r related file to %s", + self.real_extensions, str(fpath)) + payload = await self.download_file(filename) + + msg = message.Message() + msg.payload = payload + msg.meta['filepath'] = self.basedir + '/' + filename + + if not self.is_stopped(): + await super().handle(msg) + + async def tick(self): + """ + One iteration of watching. + """ + sftp_ls = await self.sftphelper.list_dir(self.basedir) + logger.critical(sftp_ls) + for filestat in sftp_ls: + logger.critical(filestat) + logger.critical(vars(filestat)) + fname = filestat.filename + if self.re.match(fname): + file_mtime = filestat.attrs.mtime + if self.last_read_mtime < file_mtime: + logger.critical(filestat) + logger.critical(self.last_read_mtime) + logger.critical(filestat.filename) + try: + # TODO: ask if a try/finally here is a good idea + await self.get_file_and_process(fname) + finally: + await self.set_last_read_mtime(file_mtime) + + async def watch_for_file(self): + """ + Watch recursively for ftp new files. + If file match regex, it is downloaded then processed in a message. + """ + while not self.is_stopped(): + await asyncio.sleep(self.interval) + try: + await self.tick() + except Exception as exc: + logger.exception(exc) + raise exc + + +class SFTPFileReader(nodes.BaseNode): + """ + Node to read a file from FTP. + """ + def __init__(self, host="", port=22, credentials=None, hostkey=None, filepath=None, + encoding="utf-8", **kwargs): + super().__init__(**kwargs) + self.filepath = filepath + self.encoding = encoding # If set to None, read as bytes + self.sftphelper = SFTPHelper(host, port, credentials, hostkey) + + async def process(self, msg): + + filepath = nodes.choose_first_not_none( + nodes.callable_or_value(self.filepath, msg), + msg.meta.get('filepath')) + content = await self.sftphelper.download_file(filepath, encoding=self.encoding) + + msg.payload = content + msg.meta['filepath'] = filepath + + return msg + + +class SFTPFileDeleter(nodes.BaseNode): + """ + Node to delete a file from SFTP. + """ + def __init__(self, host="", port=22, credentials=None, hostkey=None, filepath=None, + extensions_to_rm=None, **kwargs): + super().__init__(**kwargs) + self.filepath = filepath + self.sftphelper = SFTPHelper(host, port, credentials, hostkey) + self.extensions_to_rm = extensions_to_rm + + async def process(self, msg): + filepath = nodes.choose_first_not_none( + nodes.callable_or_value(self.filepath, msg), + msg.meta.get('filepath')) + base_fpath = os.path.splitext(msg.meta["filepath"])[0] + try: + await self.sftphelper.delete(filepath) + except Exception: + logger.exception("Trying to delete file %s but doesn't exists", filepath) + if self.extensions_to_rm: + for extension in self.extensions_to_rm: + if await self.sftphelper.file_exists(f"{base_fpath}{extension}"): + await self.sftphelper.delete(f"{base_fpath}{extension}") + logger.debug("Meta file %s deleted", f"{base_fpath}{extension}") + return msg + + +class SFTPFileWriter(nodes.BaseNode): + """ + Node to write content to SFTP. File is first written with `.part` concatenated + to its name then renamed to avoid partial upload. + """ + def __init__(self, host, port=22, credentials=None, hostkey=None, filepath=None, + create_valid_file=False, validation_extension=".ok", encoding="utf-8", + **kwargs): + + super().__init__(**kwargs) + + self.filepath = filepath + self.create_valid_file = create_valid_file + self.validation_extension = validation_extension + self.sftphelper = SFTPHelper(host, port, credentials, hostkey) + self.encoding = encoding + + async def process(self, msg): + + filepath = nodes.choose_first_not_none( + nodes.callable_or_value(self.filepath, msg), + msg.meta.get('filepath')) + content = msg.payload + if isinstance(content, str): + content = content.encode(self.encoding) + + await self.sftphelper.upload_file(filepath + '.part', content) + await self.sftphelper.rename(filepath + '.part', filepath) + if self.create_valid_file: + validation_path = Path(filepath).with_suffix(self.validation_extension) + await self.sftphelper.upload_file(str(validation_path), b"") + + return msg diff --git a/pypeman/tests/test_channel.py b/pypeman/tests/test_channel.py index e1c6fa6..cb078db 100644 --- a/pypeman/tests/test_channel.py +++ b/pypeman/tests/test_channel.py @@ -20,7 +20,8 @@ from pypeman.tests.common import MllPChannelTestThread from pypeman.tests.common import TstException from pypeman.tests.common import TstNode - +from pypeman.tests.test_sftp import MockedSFTPHelper +from pypeman.persistence import MemoryBackend logger = logging.getLogger(__name__) @@ -37,6 +38,10 @@ def raise_exc(msg): raise Exception() +# async def mocked_get_backend(loop): +# return MemoryBackend(loop=loop) + + class ChannelsTests(TestCase): def clean_loop(self): # Useful to execute future callbacks # TODO: remove ? @@ -60,7 +65,7 @@ def clean_msg(self, msg): def start_channels(self): # Start channels for chan in channels.all_channels: - self.loop.run_until_complete(chan.start()) + self.loop.create_task(chan.start()) def setUp(self): # Create class event loop used for tests to avoid failing @@ -1219,3 +1224,27 @@ def send_mllp(host, port, data_to_send): mllp_chan_thread.kill() mllp_chan_thread.join() assert n1.last_input().payload == hl7_strdata + + def test_sftpwatcher_channel(self): + fake_sftp_config = dict(host="fake", port=22, credentials=("fake", "fake")) + ftest_dir = Path(__file__).parent / "data" + ok_fpath = ftest_dir / "testfile.ok" + + with mock.patch('pypeman.contrib.sftp.SFTPHelper', new=MockedSFTPHelper): + chan = channels.SFTPWatcherChannel( + name="sftpchan", regex=".*", loop=self.loop, + basedir=str(ftest_dir), **fake_sftp_config) + n = nodes.Log(name="test_sftpwatch_chan") + chan.add(n) + chan.backend = MemoryBackend() + asyncio.run(chan.backend.start()) + n._reset_test() + self.start_channels() + # raise Exception + self.loop.run_until_complete(chan.tick()) + raise Exception + self.assertEqual(n.last_input(), None) + ok_fpath.touch() + self.loop.run_until_complete(chan.tick()) + self.assertEqual(n.last_input().payload, "testfilecontent") + self.clean_loop() diff --git a/requirements.txt b/requirements.txt index eebafa9..84fbd26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # Mandatory +asyncssh click daemonlite python-dateutil @@ -8,4 +9,3 @@ jsonrpcclient requests # For jsonrpcclient ipython sqlitedict>=2.1 # Python 3.10 DeprecationWarning: setDaemon() is deprecated - diff --git a/setup.py b/setup.py index 4aa36c5..778348c 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ test_suite="nose.collector", install_requires=[ + "asyncssh", "click", "daemonlite", "python-dateutil", @@ -58,10 +59,11 @@ "requests" ], extras_require={ + "sftp": ["asyncssh"], "hl7": ["hl7"], "xml": ["xmltodict"], "time": ["aiocron"], - "all": ["hl7", "xmltodict", "aiocron"] + "all": ["asyncssh", "hl7", "xmltodict", "aiocron"] }, setup_requires=["pytest-runner"], tests_require=[