Exception handling in online datapipes #968

Closed
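In short, the diffs below add a skip_on_error flag and pass-through of extra requests keyword arguments to HttpReader, GDriveReader, and OnlineReader. A minimal usage sketch of the new behavior, assuming only the API visible in this diff (the URLs are the ones used in the tests):

# Sketch only: demonstrates skip_on_error; the 404 URL is warned about and skipped.
from torchdata.datapipes.iter import IterableWrapper, OnlineReader

urls = [
    "https://github.com/pytorch/data/this/url/dont/exist",          # returns 404
    "https://raw.githubusercontent.com/pytorch/data/main/LICENSE",  # downloads normally
]
reader_dp = OnlineReader(IterableWrapper(urls), timeout=120, skip_on_error=True)
for file_name, stream in reader_dp:
    print(file_name)  # only the LICENSE URL yields a (name, StreamWrapper) pair; the bad URL emits a warning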
53 changes: 51 additions & 2 deletions test/test_period.py
@@ -20,7 +20,9 @@ def test_gdrive_iterdatapipe(self):
amazon_review_url = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM"
expected_file_name = "amazon_review_polarity_csv.tar.gz"
expected_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b"
gdrive_reader_dp = GDriveReader(IterableWrapper([amazon_review_url]))
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
timeout = 120
gdrive_reader_dp = GDriveReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params)

# Functional Test: test if the GDrive Reader can download and read properly
reader_dp = gdrive_reader_dp.readlines()
@@ -41,6 +43,27 @@ def test_gdrive_iterdatapipe(self):
gdrive_dp = GDriveReader(source_dp)
self.assertEqual(1, len(gdrive_dp))

# Error Test: test if the GDrive Reader raises an error when the url is invalid
error_url = "https://drive.google.com/uc?export=download&id=filedoesnotexist"
http_error_dp = GDriveReader(IterableWrapper([error_url]), timeout=timeout)
with self.assertRaisesRegex(
Exception, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist"
):
next(iter(http_error_dp.readlines()))

# Feature skip-error Test: test if the GDrive Reader skips urls causing problems
gdrive_skip_error_dp = GDriveReader(
IterableWrapper([error_url, amazon_review_url]), timeout=timeout, skip_on_error=True
)
reader_dp = gdrive_skip_error_dp.readlines()
with self.assertWarnsRegex(
Warning, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist.+skipping"
):
it = iter(reader_dp)
path, line = next(it)
self.assertEqual(expected_file_name, os.path.basename(path))
self.assertTrue(line != b"")

def test_online_iterdatapipe(self):

license_file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
@@ -49,14 +72,16 @@ def test_online_iterdatapipe(self):
expected_amazon_file_name = "amazon_review_polarity_csv.tar.gz"
expected_license_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c"
expected_amazon_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b"
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
timeout = 120

file_hash_dict = {
license_file_url: expected_license_MD5_hash,
expected_amazon_file_name: expected_amazon_MD5_hash,
}

# Functional Test: can read from GDrive links
online_reader_dp = OnlineReader(IterableWrapper([amazon_review_url]))
online_reader_dp = OnlineReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params)
reader_dp = online_reader_dp.readlines()
it = iter(reader_dp)
path, line = next(it)
@@ -89,6 +114,30 @@ def test_online_iterdatapipe(self):
# __len__ Test: returns the length of source DataPipe
self.assertEqual(2, len(online_reader_dp))

# Error Test: test if the Online Reader raises an error when the url is invalid
error_url_http = "https://github.com/pytorch/data/this/url/dont/exist"
online_error_dp = OnlineReader(IterableWrapper([error_url_http]), timeout=timeout)
with self.assertRaisesRegex(Exception, f"404.+{error_url_http}"):
next(iter(online_error_dp.readlines()))

error_url_gdrive = "https://drive.google.com/uc?export=download&id=filedoesnotexist"
online_error_dp = OnlineReader(IterableWrapper([error_url_gdrive]), timeout=timeout)
with self.assertRaisesRegex(
Exception, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist"
):
next(iter(online_error_dp.readlines()))

# Feature skip-error Test: test if the Online Reader skips urls causing problems
online_skip_error_dp = OnlineReader(
IterableWrapper([error_url_http, error_url_gdrive, license_file_url]), timeout=timeout, skip_on_error=True
)
reader_dp = online_skip_error_dp.readlines()
with self.assertWarnsRegex(Warning, f"404.+{error_url_http}.+skipping"):
it = iter(reader_dp)
path, line = next(it)
self.assertEqual(expected_license_file_name, os.path.basename(path))
self.assertTrue(b"BSD" in line)


if __name__ == "__main__":
unittest.main()
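The extra keyword arguments exercised above (auth, allow_redirects) are not interpreted by the readers; they are forwarded to requests.Session.get, which the mock-based test in test_remote_io.py below verifies. A small sketch of that pass-through, with placeholder credentials mirroring the fake ones in the tests:

from torchdata.datapipes.iter import HttpReader, IterableWrapper

# Keyword arguments other than timeout and skip_on_error are handed to requests unchanged.
http_dp = HttpReader(
    IterableWrapper(["https://raw.githubusercontent.com/pytorch/data/main/LICENSE"]),
    timeout=120,
    auth=("fake_username", "fake_password"),   # placeholder credentials
    allow_redirects=True,
)
path, stream = next(iter(http_dp))  # yields (url, StreamWrapper)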
30 changes: 25 additions & 5 deletions test/test_remote_io.py
@@ -8,16 +8,13 @@
import os
import unittest
import warnings
from unittest.mock import patch

import expecttest

import torchdata

from _utils._common_utils_for_test import check_hash_fn, create_temp_dir, IS_M1, IS_WINDOWS
from torch.utils.data import DataLoader

from torchdata.datapipes.iter import (
EndOnDiskCacheHolder,
FileOpener,
FSSpecFileLister,
FSSpecFileOpener,
@@ -27,6 +24,7 @@
S3FileLister,
S3FileLoader,
)
from torchdata.datapipes.iter.load.online import _get_proxies

try:
import fsspec
@@ -100,9 +98,31 @@ def test_http_reader_iterdatapipe(self):
# Error Test: test if the Http Reader raises an error when the url is invalid
error_url = "https://github.com/pytorch/data/this/url/dont/exist"
http_error_dp = HttpReader(IterableWrapper([error_url]), timeout=timeout)
with self.assertRaisesRegex(Exception, "[404]"):
with self.assertRaisesRegex(Exception, f"404.+{error_url}"):
next(iter(http_error_dp.readlines()))

# Feature skip-error Test: test if the Http Reader skips urls causing problems
http_skip_error_dp = HttpReader(IterableWrapper([error_url, file_url]), timeout=timeout, skip_on_error=True)
reader_dp = http_skip_error_dp.readlines()
with self.assertWarnsRegex(Warning, f"404.+{error_url}.+skipping"):
it = iter(reader_dp)
path, line = next(it)
self.assertEqual(expected_file_name, os.path.basename(path))
self.assertTrue(b"BSD" in line)

# Test if the GET request is made with the correct arguments
with patch("requests.Session.get") as mock_get:
http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
_ = next(iter(http_reader_dp))
mock_get.assert_called_with(
file_url,
timeout=timeout,
proxies=_get_proxies(),
stream=True,
auth=query_params["auth"],
allow_redirects=query_params["allow_redirects"],
)

def test_on_disk_cache_holder_iterdatapipe(self):
tar_file_url = "https://raw.githubusercontent.com/pytorch/data/main/test/_fakedata/csv.tar.gz"
expected_file_name = os.path.join(self.temp_dir.name, "csv.tar.gz")
112 changes: 74 additions & 38 deletions torchdata/datapipes/iter/load/online.py
@@ -6,13 +6,11 @@

import re
import urllib

import warnings
from typing import Any, Dict, Iterator, Optional, Tuple

import requests

from requests.exceptions import HTTPError, RequestException

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper
@@ -36,21 +34,11 @@ def _get_proxies() -> Optional[Dict[str, str]]:
def _get_response_from_http(
url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]]
) -> Tuple[str, StreamWrapper]:
try:
with requests.Session() as session:
proxies = _get_proxies()
if timeout is None:
r = session.get(url, stream=True, proxies=proxies, **query_params) # type: ignore[arg-type]
else:
r = session.get(url, timeout=timeout, stream=True, proxies=proxies, **query_params) # type: ignore[arg-type]
r.raise_for_status()
return url, StreamWrapper(r.raw)
except HTTPError as e:
raise Exception(f"Could not get the file. [HTTP Error] {e.response}.")
except RequestException as e:
raise Exception(f"Could not get the file at {url}. [RequestException] {e.response}.")
except Exception:
raise
with requests.Session() as session:
proxies = _get_proxies()
r = session.get(url, timeout=timeout, proxies=proxies, stream=True, **query_params) # type: ignore[arg-type]
r.raise_for_status()
return url, StreamWrapper(r.raw)


@functional_datapipe("read_from_http")
@@ -62,6 +50,7 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over URLs that cause errors; otherwise an exception is raised
**kwargs: a dictionary of optional arguments to pass to requests. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
@@ -80,18 +69,26 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""

def __init__(
self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None, **kwargs: Optional[Dict[str, Any]]
self,
source_datapipe: IterDataPipe[str],
timeout: Optional[float] = None,
skip_on_error: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> None:
self.source_datapipe: IterDataPipe[str] = source_datapipe
self.timeout = timeout
self.skip_on_error = skip_on_error
self.query_params = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for url in self.source_datapipe:
if self.query_params:
yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
else:
yield _get_response_from_http(url, timeout=self.timeout)
try:
yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
except Exception as e:
if self.skip_on_error:
warnings.warn(f"{e}, skipping...")
else:
raise

def __len__(self) -> int:
return len(self.source_datapipe)
@@ -102,14 +99,14 @@ def _extract_gdrive_api_response(content: str) -> Optional[str]:
return match["api_response"] if match is not None else None


def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
def _get_response_from_google_drive(
url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]]
) -> Tuple[str, StreamWrapper]:
confirm_token = None

with requests.Session() as session:
if timeout is None:
response = session.get(url, stream=True)
else:
response = session.get(url, timeout=timeout, stream=True)
response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type]
response.raise_for_status()

for k, v in response.cookies.items():
if k.startswith("download_warning"):
@@ -125,22 +122,21 @@ def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
if confirm_token:
url = url + "&confirm=" + confirm_token

if timeout is None:
response = session.get(url, stream=True)
else:
response = session.get(url, timeout=timeout, stream=True)
response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type]
response.raise_for_status()

if "content-disposition" not in response.headers:
raise RuntimeError(
"Internal error: headers don't contain content-disposition. This is usually caused by "
f"Google drive link {url} internal error: "
"headers don't contain content-disposition. This is usually caused by "
"using a sharing/viewing link instead of a download link. Click 'Download' on the "
"Google Drive page, which should redirect you to a download page, and use the link "
"of that page."
)

filename = re.findall('filename="(.+)"', response.headers["content-disposition"])
if filename is None:
raise RuntimeError("Filename could not be autodetected")
raise RuntimeError(f"Google drive link {url}: filename could not be autodetected")

return filename[0], StreamWrapper(response.raw)

@@ -154,6 +150,8 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs to GDrive files
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over URLs that cause errors; otherwise an exception is raised
**kwargs: a dictionary of optional arguments to pass to requests. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
@@ -169,13 +167,28 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""
source_datapipe: IterDataPipe[str]

def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
def __init__(
self,
source_datapipe: IterDataPipe[str],
*,
timeout: Optional[float] = None,
skip_on_error: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> None:
self.source_datapipe = source_datapipe
self.timeout = timeout
self.skip_on_error = skip_on_error
self.query_params = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for url in self.source_datapipe:
yield _get_response_from_google_drive(url, timeout=self.timeout)
try:
yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params)
except Exception as e:
if self.skip_on_error:
warnings.warn(f"{e}, skipping...")
else:
raise

def __len__(self) -> int:
return len(self.source_datapipe)
@@ -190,6 +203,8 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over URLs that cause errors; otherwise an exception is raised
**kwargs: a dictionary of optional arguments to pass to requests. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
@@ -205,18 +220,39 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""
source_datapipe: IterDataPipe[str]

def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
def __init__(
self,
source_datapipe: IterDataPipe[str],
*,
timeout: Optional[float] = None,
skip_on_error: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> None:
self.source_datapipe = source_datapipe
self.timeout = timeout
self.skip_on_error = skip_on_error
self.query_params = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for url in self.source_datapipe:
parts = urllib.parse.urlparse(url)

if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
yield _get_response_from_google_drive(url, timeout=self.timeout)
try:
yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params)
except Exception as e:
if self.skip_on_error:
warnings.warn(f"{e}, skipping...")
else:
raise
else:
yield _get_response_from_http(url, timeout=self.timeout)
try:
yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
except Exception as e:
if self.skip_on_error:
warnings.warn(f"{e}, skipping...")
else:
raise

def __len__(self) -> int:
return len(self.source_datapipe)
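Because skipped URLs are reported through warnings.warn rather than raised, a caller that wants to count or log failures can record them while iterating. A small sketch, assuming only the behavior implemented above (the URLs are the ones used in the tests):

import warnings

from torchdata.datapipes.iter import HttpReader, IterableWrapper

urls = [
    "https://github.com/pytorch/data/this/url/dont/exist",          # will 404
    "https://raw.githubusercontent.com/pytorch/data/main/LICENSE",  # will succeed
]
reader_dp = HttpReader(IterableWrapper(urls), timeout=120, skip_on_error=True).readlines()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    lines = list(reader_dp)  # iterating performs the requests; failures become warnings

print(f"skipped {len(caught)} URL(s)")  # each skipped URL emits one '..., skipping...' warning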