[client rework]: Modify fetch as a context manager
Experiment with @contextmanager decorator on RequestsFetcher.fetch()
in order to avoid unclosed connections.

Signed-off-by: Teodora Sechkova <tsechkova@vmware.com>
sechkova committed Feb 24, 2021
1 parent f3bf5f5 commit cea0c10
Showing 2 changed files with 109 additions and 95 deletions.
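
Before the diffs, for orientation: the change builds on the standard generator-based context manager pattern from contextlib, in which everything after the yield acts as guaranteed cleanup. The sketch below is illustrative only (fetch_chunks and its parameters are invented names, not this project's API) and assumes the requests library, which requests_fetcher.py already uses:

from contextlib import contextmanager

import requests


@contextmanager
def fetch_chunks(url, chunk_size=8192):
    # Acquire the resource: an open, streaming HTTP response.
    response = requests.get(url, stream=True, timeout=5)
    try:
        # Hand the caller an iterator over the body. Control returns here
        # when the caller's 'with' block exits, normally or via an exception.
        yield response.iter_content(chunk_size)
    finally:
        # Release the resource: the connection is closed even if the caller
        # raised or broke out of its loop early.
        response.close()

# Usage:
#   with fetch_chunks('https://example.com/file.bin') as chunks:
#       for chunk in chunks:
#           ...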
70 changes: 35 additions & 35 deletions tuf/download.py
@@ -195,42 +195,42 @@ def _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=True):
   average_download_speed = 0
   number_of_bytes_received = 0
 
-  try:
-    chunks = fetcher.fetch(url, required_length)
-    start_time = timeit.default_timer()
-    for chunk in chunks:
-
-      stop_time = timeit.default_timer()
-      temp_file.write(chunk)
-
-      # Measure the average download speed.
-      number_of_bytes_received += len(chunk)
-      seconds_spent_receiving = stop_time - start_time
-      average_download_speed = number_of_bytes_received / seconds_spent_receiving
-
-      if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-        logger.debug('The average download speed dropped below the minimum'
-          ' average download speed set in tuf.settings.py. Stopping the'
-          ' download!')
-        break
-
-      else:
-        logger.debug('The average download speed has not dipped below the'
-          ' minimum average download speed set in tuf.settings.py.')
-
-    # Does the total number of downloaded bytes match the required length?
-    _check_downloaded_length(number_of_bytes_received, required_length,
-        STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
-        average_download_speed=average_download_speed)
-
-  except Exception:
-    # Close 'temp_file'. Any written data is lost.
-    temp_file.close()
-    logger.debug('Could not download URL: ' + repr(url))
-    raise
-
-  else:
-    return temp_file
+  with fetcher.fetch(url, required_length) as chunks:
+    try:
+      start_time = timeit.default_timer()
+      for chunk in chunks:
+
+        stop_time = timeit.default_timer()
+        temp_file.write(chunk)
+
+        # Measure the average download speed.
+        number_of_bytes_received += len(chunk)
+        seconds_spent_receiving = stop_time - start_time
+        average_download_speed = number_of_bytes_received / seconds_spent_receiving
+
+        if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
+          logger.debug('The average download speed dropped below the minimum'
+            ' average download speed set in tuf.settings.py. Stopping the'
+            ' download!')
+          break
+
+        else:
+          logger.debug('The average download speed has not dipped below the'
+            ' minimum average download speed set in tuf.settings.py.')
+
+      # Does the total number of downloaded bytes match the required length?
+      _check_downloaded_length(number_of_bytes_received, required_length,
+          STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
+          average_download_speed=average_download_speed)
+
+    except Exception:
+      # Close 'temp_file'. Any written data is lost.
+      temp_file.close()
+      logger.debug('Could not download URL: ' + repr(url))
+      raise
+
+    else:
+      return temp_file
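The intended effect on this caller, per the commit message's goal of avoiding unclosed connections: leaving the with block, whether by finishing the loop, hitting the slow-download break, or re-raising from the except clause, resumes RequestsFetcher.fetch() after its yield, so the response can be closed at that point rather than whenever the chunks generator happens to be finalized. The snippet below is a standalone illustration of that difference in plain Python, not project code:

def gen_with_cleanup():
    try:
        yield 1
        yield 2
    finally:
        print('cleanup')

g = gen_with_cleanup()
for item in g:
    break           # leave the loop early, as the slow-download check does
# 'cleanup' has NOT run yet: g is only suspended, its finally is still pending.
g.close()           # closing the generator (or garbage collection) runs it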



134 changes: 74 additions & 60 deletions tuf/requests_fetcher.py
@@ -20,6 +20,7 @@
 import six
 import logging
 import time
+from contextlib import contextmanager
 
 import urllib3.exceptions
 import tuf.exceptions
@@ -52,71 +53,84 @@ def __init__(self):
     # minimize subtle security issues. Some cookies may not be HTTP-safe.
     self._sessions = {}
 
+
+  # @contextmanager
+  # def managed_resource(*args, **kwds):
+  #     # Code to acquire resource, e.g.:
+  #     resource = acquire_resource(*args, **kwds)
+  #     try:
+  #         yield resource
+  #     finally:
+  #         # Code to release resource, e.g.:
+  #         release_resource(resource)
+
+  @contextmanager
   def fetch(self, url, required_length):
-    # Get a customized session for each new schema+hostname combination.
-    session = self._get_session(url)
-
-    # Get the requests.Response object for this URL.
-    #
-    # Defer downloading the response body with stream=True.
-    # Always set the timeout. This timeout value is interpreted by requests as:
-    #  - connect timeout (max delay before first byte is received)
-    #  - read (gap) timeout (max delay between bytes received)
-    response = session.get(url, stream=True,
-        timeout=tuf.settings.SOCKET_TIMEOUT)
-    # Check response status.
-    try:
-      response.raise_for_status()
-    except requests.HTTPError as e:
-      status = e.response.status_code
-      raise tuf.exceptions.FetcherHTTPError(str(e), status)
-
-
-    # Define a generator function to be returned by fetch. This way the caller
-    # of fetch can differentiate between connection and actual data download
-    # and measure download times accordingly.
-    def chunks():
-      try:
-        bytes_received = 0
-        while True:
-          # We download a fixed chunk of data in every round. This is so that we
-          # can defend against slow retrieval attacks. Furthermore, we do not wish
-          # to download an extremely large file in one shot.
-          # Before beginning the round, sleep (if set) for a short amount of time
-          # so that the CPU is not hogged in the while loop.
-          if tuf.settings.SLEEP_BEFORE_ROUND:
-            time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
-
-          read_amount = min(
-              tuf.settings.CHUNK_SIZE, required_length - bytes_received)
-
-          # NOTE: This may not handle some servers adding a Content-Encoding
-          # header, which may cause urllib3 to misbehave:
-          # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
-          data = response.raw.read(read_amount)
-          bytes_received += len(data)
-
-          # We might have no more data to read. Check number of bytes downloaded.
-          if not data:
-            logger.debug('Downloaded ' + repr(bytes_received) + '/' +
-              repr(required_length) + ' bytes.')
-
-            # Finally, we signal that the download is complete.
-            break
-
-          yield data
-
-          if bytes_received >= required_length:
-            break
-
-      except urllib3.exceptions.ReadTimeoutError as e:
-        raise tuf.exceptions.SlowRetrievalError(str(e))
-
-      finally:
-        response.close()
-
-    return chunks()
+    # Get a customized session for each new schema+hostname combination.
+    session = self._get_session(url)
+
+    # Get the requests.Response object for this URL.
+    #
+    # Defer downloading the response body with stream=True.
+    # Always set the timeout. This timeout value is interpreted by requests as:
+    #  - connect timeout (max delay before first byte is received)
+    #  - read (gap) timeout (max delay between bytes received)
+    response = session.get(url, stream=True,
+        timeout=tuf.settings.SOCKET_TIMEOUT)
+    # Check response status.
+    try:
+      response.raise_for_status()
+    except requests.HTTPError as e:
+      status = e.response.status_code
+      raise tuf.exceptions.FetcherHTTPError(str(e), status)
+
+
+    # Define a generator function to be returned by fetch. This way the caller
+    # of fetch can differentiate between connection and actual data download
+    # and measure download times accordingly.
+    def chunks():
+      try:
+        bytes_received = 0
+        while True:
+          # We download a fixed chunk of data in every round. This is so that we
+          # can defend against slow retrieval attacks. Furthermore, we do not wish
+          # to download an extremely large file in one shot.
+          # Before beginning the round, sleep (if set) for a short amount of time
+          # so that the CPU is not hogged in the while loop.
+          if tuf.settings.SLEEP_BEFORE_ROUND:
+            time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
+
+          read_amount = min(
+              tuf.settings.CHUNK_SIZE, required_length - bytes_received)
+
+          # NOTE: This may not handle some servers adding a Content-Encoding
+          # header, which may cause urllib3 to misbehave:
+          # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
+          data = response.raw.read(read_amount)
+          bytes_received += len(data)
+
+          # We might have no more data to read. Check number of bytes downloaded.
+          if not data:
+            logger.debug('Downloaded ' + repr(bytes_received) + '/' +
+              repr(required_length) + ' bytes.')
+
+            # Finally, we signal that the download is complete.
+            break
+
+          yield data
+
+          if bytes_received >= required_length:
+            break
+
+        response.close()
+
+      except urllib3.exceptions.ReadTimeoutError as e:
+        raise tuf.exceptions.SlowRetrievalError(str(e))
+
+    try:
+      yield chunks()
+
+    finally:
+      response.close()
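One general consequence of the decorator, noted as background rather than anything stated in the diff: a @contextmanager-decorated fetch() must be consumed through a with statement. Calling it no longer returns the chunks iterator directly, and its body does not start executing (so no connection is opened) until the with statement enters it. A minimal sketch with a stand-in fetcher (StubFetcher is invented for illustration, not the real RequestsFetcher):

from contextlib import contextmanager


class StubFetcher:
    @contextmanager
    def fetch(self, url, required_length):
        print('connect')                      # stands in for session.get()
        try:
            yield iter([b'abc', b'def'])      # stands in for chunks()
        finally:
            print('close')                    # stands in for response.close()


fetcher = StubFetcher()
cm = fetcher.fetch('https://example.com', 6)  # nothing has executed yet
with cm as chunks:                            # prints 'connect'
    print(b''.join(chunks))                   # prints b'abcdef'
# prints 'close' when the block exits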



