From f08948fd63b05654379af64b023b63211b171e42 Mon Sep 17 00:00:00 2001 From: jj Date: Mon, 23 Jan 2023 11:30:18 +0100 Subject: [PATCH] iter content in requests seems to work better (both image serrvice and git) --- image_transfer.py | 20 +++++++------------- plainhttp.py | 25 ++----------------------- tests/test_http2ssh.py | 8 +++----- 3 files changed, 12 insertions(+), 41 deletions(-) diff --git a/image_transfer.py b/image_transfer.py index 33db705..8396b81 100644 --- a/image_transfer.py +++ b/image_transfer.py @@ -18,6 +18,7 @@ def file_exist(sftp, name): except: return -1 + def http2ssh(url: str, ssh_client, remote_name: str, force=True): sftp_client = ssh_client.open_sftp() size = file_exist(sftp=sftp_client, name=remote_name) @@ -30,26 +31,19 @@ def http2ssh(url: str, ssh_client, remote_name: str, force=True): dirname = os.path.dirname(remote_name) ssh_client.exec_command(command=f"mkdir -p {dirname}") ssh_client.exec_command(command=f"touch {remote_name}") - + with requests.get(url, stream=True, verify=False, timeout=(2,3)) as r: written = 0 - with sftp_client.open(remote_name, 'w') as f: + with sftp_client.open(remote_name, 'wb') as f: f.set_pipelined(pipelined=True) - while True: - chunk=r.raw.read(1024 * 1000) - if not chunk: - break + for chunk in r.iter_content(chunk_size=1024*1000): + written+=len(chunk) content_to_write = memoryview(chunk) f.write(content_to_write) - written+=len(chunk) - cl = r.headers.get('Content-Length', 0) - print(f"Written: {written} Content-Length: {cl}") - if cl!=written: - print('Content length mismatch') - return -1 + + print(f"Written {written} bytes") return 0 - @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) def transfer_image(): diff --git a/plainhttp.py b/plainhttp.py index 267d675..26194d6 100644 --- a/plainhttp.py +++ b/plainhttp.py @@ -8,33 +8,12 @@ from airflow.utils.dates import days_ago from decors import get_connection, remove, setup -from image_transfer import file_exist +from image_transfer import http2ssh default_args = { 'owner': 'airflow', } -def http2ssh(url: str, ssh_client, remote_name: str, force=True): - sftp_client = ssh_client.open_sftp() - size = file_exist(sftp=sftp_client, name=remote_name) - if size>0: - print(f"File {remote_name} exists and has {size} bytes") - if force is not True: - return 0 - print("Forcing overwrite") - - dirname = os.path.dirname(remote_name) - ssh_client.exec_command(command=f"mkdir -p {dirname}") - ssh_client.exec_command(command=f"touch {remote_name}") - - with requests.get(url, stream=True, verify=False, timeout=(2,3)) as r: - with sftp_client.open(remote_name, 'w') as f: - f.set_pipelined(pipelined=True) - for chunk in r.iter_content(chunk_size=1024*1000): - content_to_write = memoryview(chunk) - f.write(content_to_write) - - return 0 @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['wp4', 'http', 'ssh']) def plainhttp2ssh(): @@ -46,7 +25,7 @@ def stream_upload(connection_id, **kwargs): target = params.get('target', '/tmp/') url = params.get('url', '') if not url: - print('Provide valid url') + print('Provide a valid url') return -1 print(f"Putting {url} --> {target}") diff --git a/tests/test_http2ssh.py b/tests/test_http2ssh.py index c38581a..a0df16c 100755 --- a/tests/test_http2ssh.py +++ b/tests/test_http2ssh.py @@ -51,8 +51,7 @@ def test_actual_cpy(self, exists, get): my_client.exec_command = exec - get().__enter__().raw.read = MagicMock(side_effect=[b'blabla', None]) - get().__enter__().headers.get = MagicMock(return_value=6) + get().__enter__().iter_content = MagicMock(return_value=[b'blabla']) r = http2ssh(url='foo.bar', ssh_client=my_client, remote_name='/goo/bar', force=True) self.assertEqual(r, 0) exec.assert_called() @@ -70,10 +69,9 @@ def test_missed_cpy(self, exists, get): my_client.open_sftp.return_value = my_sftp my_client.exec_command = exec + get().__enter__().iter_content = MagicMock(return_value=[b'blabla']) - get().__enter__().raw.read = MagicMock(side_effect=[b'blabla', None]) - get().__enter__().headers.get = MagicMock(return_value=699) r = http2ssh(url='foo.bar', ssh_client=my_client, remote_name='/goo/bar', force=True) - self.assertEqual(r, -1) + self.assertEqual(r, 0) exec.assert_called() wrt.assert_called_once_with(memoryview(b'blabla'))