Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retries to file download to avoid throwing an exception in the case of normal retry errors #1002

Merged
merged 11 commits into from
Jul 29, 2023
35 changes: 27 additions & 8 deletions planet/clients/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
import uuid

from tqdm.asyncio import tqdm

from ..data_filter import empty_filter
from .. import exceptions
from ..constants import PLANET_BASE_URL
from ..http import Session
from ..models import Paged, StreamingBody
from ..models import Paged
from ..specs import validate_data_item_type

BASE_URL = f'{PLANET_BASE_URL}/data/v1/'
Expand Down Expand Up @@ -595,13 +597,30 @@ async def download_asset(self,
raise exceptions.ClientError(
'asset missing ["location"] entry. Is asset active?')

async with self._session.stream(method='GET', url=location) as resp:
body = StreamingBody(resp)
dl_path = Path(directory, filename or body.name)
dl_path.parent.mkdir(exist_ok=True, parents=True)
await body.write(dl_path,
overwrite=overwrite,
progress_bar=progress_bar)
response = await self._session.request(method='GET', url=location)
filename = filename or response.filename
if not filename:
raise exceptions.ClientError(
f'Could not determine filename at {location}')

dl_path = Path(directory, filename)
dl_path.parent.mkdir(exist_ok=True, parents=True)
LOGGER.info(f'Downloading {dl_path}')

try:
mode = 'wb' if overwrite else 'xb'
with open(dl_path, mode) as fp:
with tqdm(total=response.length,
unit_scale=True,
unit_divisor=1024 * 1024,
unit='B',
desc=str(filename),
disable=not progress_bar) as progress:
update = progress.update if progress_bar else LOGGER.debug
await self._session.write(location, fp, update)
except FileExistsError:
LOGGER.info(f'File {dl_path} exists, not overwriting')

return dl_path

@staticmethod
Expand Down
40 changes: 31 additions & 9 deletions planet/clients/orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
"""Functionality for interacting with the orders api"""
import asyncio
import logging
from pathlib import Path
import time
from typing import AsyncIterator, Callable, List, Optional
import uuid
import json
import hashlib

from pathlib import Path
from tqdm.asyncio import tqdm

from .. import exceptions
from ..constants import PLANET_BASE_URL
from ..http import Session
from ..models import Paged, StreamingBody
from ..models import Paged

BASE_URL = f'{PLANET_BASE_URL}/compute/ops'
STATS_PATH = '/stats/orders/v2'
Expand Down Expand Up @@ -251,14 +253,34 @@ async def download_asset(self,

Raises:
planet.exceptions.APIError: On API error.
planet.exceptions.ClientError: If location is not valid or retry
limit is exceeded.

"""
async with self._session.stream(method='GET', url=location) as resp:
body = StreamingBody(resp)
dl_path = Path(directory, filename or body.name)
dl_path.parent.mkdir(exist_ok=True, parents=True)
await body.write(dl_path,
overwrite=overwrite,
progress_bar=progress_bar)
response = await self._session.request(method='GET', url=location)
filename = filename or response.filename
length = response.length
if not filename:
raise exceptions.ClientError(
f'Could not determine filename at {location}')

dl_path = Path(directory, filename)
dl_path.parent.mkdir(exist_ok=True, parents=True)
LOGGER.info(f'Downloading {dl_path}')

try:
mode = 'wb' if overwrite else 'xb'
with open(dl_path, mode) as fp:
with tqdm(total=length,
unit_scale=True,
unit_divisor=1024 * 1024,
unit='B',
desc=str(filename),
disable=not progress_bar) as progress:
await self._session.write(location, fp, progress.update)
except FileExistsError:
LOGGER.info(f'File {dl_path} exists, not overwriting')

return dl_path

async def download_order(self,
Expand Down
44 changes: 25 additions & 19 deletions planet/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from __future__ import annotations # https://stackoverflow.com/a/33533514
import asyncio
from collections import Counter
from contextlib import asynccontextmanager
from http import HTTPStatus
import logging
import random
import time
from typing import AsyncGenerator, Optional
from typing import Callable, Optional

import httpx
from typing_extensions import Literal
Expand All @@ -42,7 +41,7 @@
httpx.ReadTimeout,
httpx.RemoteProtocolError,
exceptions.BadGateway,
exceptions.TooManyRequests
exceptions.TooManyRequests,
]
MAX_RETRIES = 5
MAX_RETRY_BACKOFF = 64 # seconds
Expand Down Expand Up @@ -327,6 +326,7 @@ async def _retry(self, func, *a, **kw):
LOGGER.info(f'Retrying: sleeping {wait_time}s')
await asyncio.sleep(wait_time)
else:
LOGGER.info('Retrying: failed')
raise e

self.outcomes.update(['Successful'])
Expand Down Expand Up @@ -394,26 +394,32 @@ async def _send(self, request, stream=False) -> httpx.Response:

return http_resp

@asynccontextmanager
async def stream(
self, method: str,
url: str) -> AsyncGenerator[models.StreamingResponse, None]:
"""Submit a request and get the response as a stream context manager.
async def write(self, url: str, fp, callback: Optional[Callable] = None):
"""Write data to local file with limiting and retries.

Parameters:
method: HTTP request method.
url: Location of the API endpoint.
url: Remote location url.
fp: Open write file pointer.
callback: Function that handles write progress updates.

Raises:
planet.exceptions.APIException: On API error.

Returns:
Context manager providing the streaming response.
"""
request = self._client.build_request(method=method, url=url)
http_response = await self._retry(self._send, request, stream=True)
response = models.StreamingResponse(http_response)
try:
yield response
finally:
await response.aclose()

async def _limited_write():
async with self._limiter:
async with self._client.stream('GET', url) as response:
previous = response.num_bytes_downloaded

async for chunk in response.aiter_bytes():
fp.write(chunk)
current = response.num_bytes_downloaded
if callback is not None:
callback(current - previous)
previous = current

await self._retry(_limited_write)

def client(self,
name: Literal['data', 'orders', 'subscriptions'],
Expand Down
155 changes: 31 additions & 124 deletions planet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
# limitations under the License.
"""Manage data for requests and responses."""
import logging
import mimetypes
from pathlib import Path
import random
import re
import string
from typing import AsyncGenerator, Callable, List, Optional
from urllib.parse import urlparse

import httpx
from tqdm.asyncio import tqdm

from .exceptions import PagingError

Expand All @@ -49,134 +44,59 @@ def status_code(self) -> int:
"""HTTP status code"""
return self._http_response.status_code

def json(self) -> dict:
"""Response json"""
return self._http_response.json()


class StreamingResponse(Response):

@property
def headers(self) -> httpx.Headers:
return self._http_response.headers

@property
def url(self) -> str:
return str(self._http_response.url)
def filename(self) -> Optional[str]:
"""Name of the download file.

@property
def num_bytes_downloaded(self) -> int:
return self._http_response.num_bytes_downloaded
The filename is None if the response does not represent a download.
"""
filename = None

async def aiter_bytes(self):
async for c in self._http_response.aiter_bytes():
yield c
if self.length is not None: # is a download file
filename = _get_filename_from_response(self._http_response)

async def aclose(self):
await self._http_response.aclose()
return filename

@property
def length(self) -> Optional[int]:
"""Length of the download file.

class StreamingBody:
"""A representation of a streaming resource from the API."""
The length is None if the response does not represent a download.
"""
LOGGER.warning('here')
try:
length = int(self._http_response.headers["Content-Length"])
except KeyError:
length = None
LOGGER.warning(length)
return length

def __init__(self, response: StreamingResponse):
"""Initialize the object.
def json(self) -> dict:
"""Response json"""
return self._http_response.json()

Parameters:
response: Response that was received from the server.
"""
self._response = response

@property
def name(self) -> str:
"""The name of this resource.
def _get_filename_from_response(response) -> Optional[str]:
"""The name of the response resource.

The default is to use the content-disposition header value from the
response. If not found, falls back to resolving the name from the url
or generating a random name with the type from the response.
"""
name = (_get_filename_from_headers(self._response.headers)
or _get_filename_from_url(self._response.url)
or _get_random_filename(
self._response.headers.get('content-type')))
return name

@property
def size(self) -> int:
"""The size of the body."""
return int(self._response.headers['Content-Length'])

async def write(self,
filename: Path,
overwrite: bool = True,
progress_bar: bool = True):
"""Write the body to a file.
Parameters:
filename: Name to assign to downloaded file.
overwrite: Overwrite any existing files.
progress_bar: Show progress bar during download.
"""

class _LOG:

def __init__(self, total, unit, filename, disable):
self.total = total
self.unit = unit
self.disable = disable
self.previous = 0
self.filename = str(filename)

if not self.disable:
LOGGER.debug(f'writing to {self.filename}')

def update(self, new):
if new - self.previous > self.unit and not self.disable:
# LOGGER.debug(f'{new-self.previous}')
perc = int(100 * new / self.total)
LOGGER.debug(f'{self.filename}: '
f'wrote {perc}% of {self.total}')
self.previous = new
name = (_get_filename_from_headers(response.headers)
or _get_filename_from_url(str(response.url)))
return name

unit = 1024 * 1024

mode = 'wb' if overwrite else 'xb'
try:
with open(filename, mode) as fp:
_log = _LOG(self.size,
16 * unit,
filename,
disable=progress_bar)
with tqdm(total=self.size,
unit_scale=True,
unit_divisor=unit,
unit='B',
desc=str(filename),
disable=not progress_bar) as progress:
previous = self._response.num_bytes_downloaded
async for chunk in self._response.aiter_bytes():
fp.write(chunk)
new = self._response.num_bytes_downloaded
_log.update(new)
progress.update(new - previous)
previous = new
except FileExistsError:
LOGGER.info(f'File {filename} exists, not overwriting')


def _get_filename_from_headers(headers):
"""Get a filename from the Content-Disposition header, if available.

:param headers dict: a ``dict`` of response headers
:returns: a filename (i.e. ``basename``)
:rtype: str or None
"""
def _get_filename_from_headers(headers: httpx.Headers) -> Optional[str]:
"""Get a filename from the Content-Disposition header, if available."""
cd = headers.get('content-disposition', '')
match = re.search('filename="?([^"]+)"?', cd)
return match.group(1) if match else None


def _get_filename_from_url(url: str) -> Optional[str]:
"""Get a filename from a url.
"""Get a filename from the url.

Getting a name for Landsat imagery uses this function.
"""
Expand All @@ -185,19 +105,6 @@ def _get_filename_from_url(url: str) -> Optional[str]:
return name or None


def _get_random_filename(content_type=None):
"""Get a pseudo-random, Planet-looking filename.

:returns: a filename (i.e. ``basename``)
:rtype: str
"""
extension = mimetypes.guess_extension(content_type or '') or ''
characters = string.ascii_letters + '0123456789'
letters = ''.join(random.sample(characters, 8))
name = 'planet-{}{}'.format(letters, extension)
return name


class Paged:
"""Asynchronous iterator over results in a paged resource.

Expand Down
Loading