Skip to content

Commit

Permalink
[AIRFLOW-393] Add callback for FTP downloads
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Kosel committed Sep 28, 2018
1 parent f4f8027 commit 9325330
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 6 deletions.
68 changes: 62 additions & 6 deletions airflow/contrib/hooks/ftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ def delete_directory(self, path):
conn = self.get_conn()
conn.rmd(path)

def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
def retrieve_file(
self,
remote_full_path,
local_full_path_or_buffer,
callback=None):
"""
Transfers the remote file to a local location.
Expand All @@ -161,23 +165,59 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
:param local_full_path_or_buffer: full path to the local file or a
file-like buffer
:type local_full_path_or_buffer: str or file-like buffer
:param callback: callback which is called each time a block of data
is read. If you do not use a callback, these blocks will be written
to the file or buffer passed in. If you do pass in a callback, note
that writing to a file or buffer will need to be handled inside the
callback.
[default: output_handle.write]
:type callback: callable
Example::
hook = FTPHook(ftp_conn_id='my_conn')
remote_path = '/path/to/remote/file'
local_path = '/path/to/local/file'
# with a custom callback (in this case displaying progress on each read)
def print_progress(percent_progress):
self.log.info('Percent Downloaded: %s%%' % percent_progress)
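# keep the running total in a mutable container: rebinding a plain
# outer variable with ``+=`` inside the callback raises UnboundLocalError
progress = {'downloaded': 0}
total_file_size = hook.get_size(remote_path)
output_handle = open(local_path, 'wb')
def write_to_file_with_progress(data):
progress['downloaded'] += len(data)
output_handle.write(data)
percent_progress = (progress['downloaded'] * 100.0) / total_file_size
print_progress(percent_progress)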
hook.retrieve_file(remote_path, None, callback=write_to_file_with_progress)
# without a custom callback data is written to the local_path
hook.retrieve_file(remote_path, local_path)
"""
conn = self.get_conn()

is_path = isinstance(local_full_path_or_buffer, basestring)

if is_path:
output_handle = open(local_full_path_or_buffer, 'wb')
# without a callback, default to writing to a user-provided file or
# file-like buffer
if not callback:
if is_path:
output_handle = open(local_full_path_or_buffer, 'wb')
else:
output_handle = local_full_path_or_buffer
callback = output_handle.write
else:
output_handle = local_full_path_or_buffer
output_handle = None

remote_path, remote_file_name = os.path.split(remote_full_path)
conn.cwd(remote_path)
self.log.info('Retrieving file from FTP: %s', remote_full_path)
conn.retrbinary('RETR %s' % remote_file_name, output_handle.write)
conn.retrbinary('RETR %s' % remote_file_name, callback)
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)

if is_path:
if is_path and output_handle:
output_handle.close()

def store_file(self, remote_full_path, local_full_path_or_buffer):
Expand Down Expand Up @@ -230,6 +270,12 @@ def rename(self, from_name, to_name):
return conn.rename(from_name, to_name)

def get_mod_time(self, path):
"""
Returns a datetime object representing the last time the file was modified.
:param path: remote file path
:type path: string
"""
conn = self.get_conn()
ftp_mdtm = conn.sendcmd('MDTM ' + path)
time_val = ftp_mdtm[4:]
Expand All @@ -239,6 +285,16 @@ def get_mod_time(self, path):
except ValueError:
return datetime.datetime.strptime(time_val, '%Y%m%d%H%M%S')

def get_size(self, path):
    """
    Returns the size of a file (in bytes)

    The value is obtained by calling ``size(path)`` on the connection
    returned by ``get_conn()`` and is handed back unchanged.

    :param path: remote file path
    :type path: string
    :return: whatever the connection's ``size`` call reports for ``path``
        (NOTE(review): if the connection is an ``ftplib.FTP``, this can
        be ``None`` when the server does not answer SIZE successfully --
        confirm and handle in callers)
    """
    return self.get_conn().size(path)


class FTPSHook(FTPHook):

Expand Down
23 changes: 23 additions & 0 deletions tests/contrib/hooks/test_ftp_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#

import mock
import six
import unittest

from airflow.contrib.hooks import ftp_hook as fh
Expand Down Expand Up @@ -101,6 +102,28 @@ def test_mod_time_micro(self):

self.conn_mock.sendcmd.assert_called_once_with('MDTM ' + path)

def test_get_size(self):
    """
    get_size must forward the path to the connection's ``size`` call
    and hand the reported value back to the caller.
    """
    expected_size = 1942
    self.conn_mock.size.return_value = expected_size

    path = '/path/file'
    with fh.FTPHook() as ftp_hook:
        size = ftp_hook.get_size(path)

    self.conn_mock.size.assert_called_once_with(path)
    # Without this check the configured return_value is never observed,
    # so a hook that dropped the result would still pass.
    self.assertEqual(size, expected_size)

def test_retrieve_file(self):
    """
    With no callback supplied, the buffer's own ``write`` method is what
    gets passed to ``retrbinary`` as the per-block handler.
    """
    target = six.StringIO('buffer')

    with fh.FTPHook() as hook:
        hook.retrieve_file(self.path, target)

    retrbinary = self.conn_mock.retrbinary
    retrbinary.assert_called_once_with('RETR path', target.write)

def test_retrieve_file_with_callback(self):
    """
    A caller-supplied callback is passed to ``retrbinary`` in place of
    the buffer's ``write`` method.
    """
    on_block = mock.Mock()
    target = six.StringIO('buffer')

    with fh.FTPHook() as hook:
        hook.retrieve_file(self.path, target, callback=on_block)

    retrbinary = self.conn_mock.retrbinary
    retrbinary.assert_called_once_with('RETR path', on_block)


# Script entry point: hand control to the unittest runner when this module
# is executed directly rather than imported.
if __name__ == '__main__':
unittest.main()

0 comments on commit 9325330

Please sign in to comment.