Exception handling in online datapipes #968

Closed
46 changes: 44 additions & 2 deletions test/test_period.py
@@ -20,7 +20,9 @@ def test_gdrive_iterdatapipe(self):
amazon_review_url = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM"
expected_file_name = "amazon_review_polarity_csv.tar.gz"
expected_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b"
gdrive_reader_dp = GDriveReader(IterableWrapper([amazon_review_url]))
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
timeout = 120
gdrive_reader_dp = GDriveReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params)

# Functional Test: test if the GDrive Reader can download and read properly
reader_dp = gdrive_reader_dp.readlines()
@@ -41,6 +43,21 @@ def test_gdrive_iterdatapipe(self):
gdrive_dp = GDriveReader(source_dp)
self.assertEqual(1, len(gdrive_dp))

# Error Test: test if the GDrive Reader raises an error when the url is invalid
error_url = "https://drive.google.com/uc?export=download&id=filedoesnotexist"
http_error_dp = GDriveReader(IterableWrapper([error_url]))
with self.assertRaisesRegex(Exception, "[404]"):
    next(iter(http_error_dp.readlines()))

# Feature skip-error Test: test if the GDrive Reader skips urls causing problems
gdrive_skip_error_dp = GDriveReader(IterableWrapper([error_url, amazon_review_url]), skip_on_error=True)
reader_dp = gdrive_skip_error_dp.readlines()
with self.assertWarnsRegex(Warning, f"404.+{error_url}.+skipping"):
    it = iter(reader_dp)
    path, line = next(it)
self.assertEqual(expected_file_name, os.path.basename(path))
self.assertTrue(line != b"")

def test_online_iterdatapipe(self):

license_file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
@@ -49,14 +66,16 @@ def test_online_iterdatapipe(self):
expected_amazon_file_name = "amazon_review_polarity_csv.tar.gz"
expected_license_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c"
expected_amazon_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b"
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
timeout = 120

file_hash_dict = {
    license_file_url: expected_license_MD5_hash,
    expected_amazon_file_name: expected_amazon_MD5_hash,
}

# Functional Test: can read from GDrive links
online_reader_dp = OnlineReader(IterableWrapper([amazon_review_url]))
online_reader_dp = OnlineReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params)
reader_dp = online_reader_dp.readlines()
it = iter(reader_dp)
path, line = next(it)
@@ -89,6 +108,29 @@ def test_online_iterdatapipe(self):
# __len__ Test: returns the length of source DataPipe
self.assertEqual(2, len(online_reader_dp))

# Error Test: test if the Online Reader raises an error when the url is invalid
error_url_http = "https://github.com/pytorch/data/this/url/dont/exist"
online_skip_error_dp = OnlineReader(IterableWrapper([error_url_http]))
with self.assertRaisesRegex(Exception, f"404.+{error_url_http}"):
    next(iter(online_skip_error_dp.readlines()))

error_url_gdrive = "https://drive.google.com/uc?export=download&id=filedoesnotexist"
online_skip_error_dp = OnlineReader(IterableWrapper([error_url_gdrive]))
with self.assertRaisesRegex(Exception, f"404.+{error_url_gdrive}"):
    next(iter(online_skip_error_dp.readlines()))

# Feature skip-error Test: test if the Online Reader skips urls causing problems
error_url_http = "https://github.com/pytorch/data/this/url/dont/exist"
online_skip_error_dp = OnlineReader(
    IterableWrapper([error_url_http, error_url_gdrive, license_file_url]), skip_on_error=True
)
reader_dp = online_skip_error_dp.readlines()
with self.assertWarnsRegex(Warning, f"404.+{error_url_gdrive}.+skipping"):
    it = iter(reader_dp)
    path, line = next(it)
self.assertEqual(expected_license_file_name, os.path.basename(path))
self.assertTrue(b"BSD" in line)


if __name__ == "__main__":
    unittest.main()
30 changes: 25 additions & 5 deletions test/test_remote_io.py
@@ -8,16 +8,13 @@
import os
import unittest
import warnings
from unittest.mock import patch

import expecttest

import torchdata

from _utils._common_utils_for_test import check_hash_fn, create_temp_dir, IS_M1, IS_WINDOWS
from torch.utils.data import DataLoader

from torchdata.datapipes.iter import (
    EndOnDiskCacheHolder,
    FileOpener,
    FSSpecFileLister,
    FSSpecFileOpener,
@@ -27,6 +24,7 @@
    S3FileLister,
    S3FileLoader,
)
from torchdata.datapipes.iter.load.online import _get_proxies

try:
    import fsspec
@@ -100,9 +98,31 @@ def test_http_reader_iterdatapipe(self):
# Error Test: test if the Http Reader raises an error when the url is invalid
error_url = "https://github.com/pytorch/data/this/url/dont/exist"
http_error_dp = HttpReader(IterableWrapper([error_url]), timeout=timeout)
with self.assertRaisesRegex(Exception, "[404]"):
with self.assertRaisesRegex(Exception, f"404.+{error_url}"):
    next(iter(http_error_dp.readlines()))

# Feature skip-error Test: test if the Http Reader skips urls causing problems
http_skip_error_dp = HttpReader(IterableWrapper([error_url, file_url]), timeout=timeout, skip_on_error=True)
reader_dp = http_skip_error_dp.readlines()
with self.assertWarnsRegex(Warning, f"404.+{error_url}.+skipping"):
    it = iter(reader_dp)
    path, line = next(it)
self.assertEqual(expected_file_name, os.path.basename(path))
self.assertTrue(b"BSD" in line)

# test if GET-request is done with correct arguments
with patch("requests.Session.get") as mock_get:
    http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
    _ = next(iter(http_reader_dp))
    mock_get.assert_called_with(
        file_url,
        timeout=timeout,
        proxies=_get_proxies(),
        stream=True,
        auth=query_params["auth"],
        allow_redirects=query_params["allow_redirects"],
    )

def test_on_disk_cache_holder_iterdatapipe(self):
    tar_file_url = "https://raw.githubusercontent.com/pytorch/data/main/test/_fakedata/csv.tar.gz"
    expected_file_name = os.path.join(self.temp_dir.name, "csv.tar.gz")
115 changes: 73 additions & 42 deletions torchdata/datapipes/iter/load/online.py
@@ -6,12 +6,11 @@

import re
import urllib

import warnings
from typing import Any, Dict, Iterator, Optional, Tuple

import requests

from requests.exceptions import HTTPError, RequestException
from requests.exceptions import RequestException

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
@@ -36,21 +35,11 @@ def _get_proxies() -> Optional[Dict[str, str]]:
def _get_response_from_http(
    url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]]
) -> Tuple[str, StreamWrapper]:
    try:
        with requests.Session() as session:
            proxies = _get_proxies()
            if timeout is None:
                r = session.get(url, stream=True, proxies=proxies, **query_params)  # type: ignore[arg-type]
            else:
                r = session.get(url, timeout=timeout, stream=True, proxies=proxies, **query_params)  # type: ignore[arg-type]
            r.raise_for_status()
            return url, StreamWrapper(r.raw)
    except HTTPError as e:
        raise Exception(f"Could not get the file. [HTTP Error] {e.response}.")
    except RequestException as e:
        raise Exception(f"Could not get the file at {url}. [RequestException] {e.response}.")
    except Exception:
        raise
    with requests.Session() as session:
        proxies = _get_proxies()
        r = session.get(url, timeout=timeout, proxies=proxies, stream=True, **query_params)  # type: ignore[arg-type]
        r.raise_for_status()
        return url, StreamWrapper(r.raw)


@functional_datapipe("read_from_http")
@@ -62,6 +51,7 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised
**kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/

Example:
@@ -80,18 +70,26 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""

def __init__(
    self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None, **kwargs: Optional[Dict[str, Any]]
    self,
    source_datapipe: IterDataPipe[str],
    timeout: Optional[float] = None,
    skip_on_error: bool = False,
    **kwargs: Optional[Dict[str, Any]],
) -> None:
    self.source_datapipe: IterDataPipe[str] = source_datapipe
    self.timeout = timeout
    self.skip_on_error = skip_on_error
    self.query_params = kwargs
def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
    for url in self.source_datapipe:
        if self.query_params:
        try:
            yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
        else:
            yield _get_response_from_http(url, timeout=self.timeout)
        except RequestException as e:
            if self.skip_on_error:
                warnings.warn(f"{e}, skipping...")
            else:
                raise

def __len__(self) -> int:
    return len(self.source_datapipe)
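
For readers of this diff, a minimal usage sketch of the HttpReader behavior added here (skip_on_error plus extra requests keyword arguments forwarded to Session.get). The URLs, timeout value, and allow_redirects flag are placeholders drawn from the tests above, not part of the patch itself:

from torchdata.datapipes.iter import HttpReader, IterableWrapper

# One reachable URL and one that should return a 404; both are placeholder examples.
urls = [
    "https://raw.githubusercontent.com/pytorch/data/main/LICENSE",
    "https://github.com/pytorch/data/this/url/dont/exist",
]

# skip_on_error=True turns failed requests into warnings instead of raised exceptions;
# extra keyword arguments (e.g. allow_redirects) are forwarded to requests.Session.get.
dp = HttpReader(IterableWrapper(urls), timeout=120, skip_on_error=True, allow_redirects=True)

for url, stream in dp:
    # Only the reachable URL is yielded; the failing one is skipped with a warning.
    print(url)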
@@ -102,14 +100,14 @@ def _extract_gdrive_api_response(content: str) -> Optional[str]:
return match["api_response"] if match is not None else None


def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
def _get_response_from_google_drive(
    url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]]
) -> Tuple[str, StreamWrapper]:
    confirm_token = None

    with requests.Session() as session:
        if timeout is None:
            response = session.get(url, stream=True)
        else:
            response = session.get(url, timeout=timeout, stream=True)
        response = session.get(url, timeout=timeout, stream=True, **query_params)
        response.raise_for_status()

        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
@@ -125,22 +123,21 @@ def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
        if confirm_token:
            url = url + "&confirm=" + confirm_token

        if timeout is None:
            response = session.get(url, stream=True)
        else:
            response = session.get(url, timeout=timeout, stream=True)
        response = session.get(url, timeout=timeout, stream=True, **query_params)
        response.raise_for_status()

        if "content-disposition" not in response.headers:
            raise RuntimeError(
                "Internal error: headers don't contain content-disposition. This is usually caused by "
                f"Google drive link {url} internal error: "
                "headers don't contain content-disposition. This is usually caused by "
                "using a sharing/viewing link instead of a download link. Click 'Download' on the "
                "Google Drive page, which should redirect you to a download page, and use the link "
                "of that page."
            )

        filename = re.findall('filename="(.+)"', response.headers["content-disposition"])
        if filename is None:
            raise RuntimeError("Filename could not be autodetected")
            raise RuntimeError(f"Google drive link {url}: filename could not be autodetected")

        return filename[0], StreamWrapper(response.raw)

@@ -154,6 +151,8 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs to GDrive files
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised
**kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
@@ -169,13 +168,28 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""
source_datapipe: IterDataPipe[str]

def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
def __init__(
    self,
    source_datapipe: IterDataPipe[str],
    *,
    timeout: Optional[float] = None,
    skip_on_error: bool = False,
    **kwargs: Optional[Dict[str, Any]],
) -> None:
    self.source_datapipe = source_datapipe
    self.timeout = timeout
    self.skip_on_error = skip_on_error
    self.query_params = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
    for url in self.source_datapipe:
        yield _get_response_from_google_drive(url, timeout=self.timeout)
        try:
            yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params)
        except Exception as e:
            if self.skip_on_error:
                warnings.warn(f"{e}, skipping...")
            else:
                raise

def __len__(self) -> int:
    return len(self.source_datapipe)
@@ -190,6 +204,8 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: a DataPipe that contains URLs
timeout: timeout in seconds for HTTP request
skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised
**kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
@@ -205,18 +221,33 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
"""
source_datapipe: IterDataPipe[str]

def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
def __init__(
    self,
    source_datapipe: IterDataPipe[str],
    *,
    timeout: Optional[float] = None,
    skip_on_error: bool = False,
    **kwargs: Optional[Dict[str, Any]],
) -> None:
    self.source_datapipe = source_datapipe
    self.timeout = timeout
    self.skip_on_error = skip_on_error
    self.query_params = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
    for url in self.source_datapipe:
        parts = urllib.parse.urlparse(url)

        if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
            yield _get_response_from_google_drive(url, timeout=self.timeout)
        else:
            yield _get_response_from_http(url, timeout=self.timeout)
        try:
            parts = urllib.parse.urlparse(url)

            if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
                yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params)
            else:
                yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
Review discussion on the lines above:

Contributor:
Can you please wrap try-except around _get_response_from_google_drive or _get_response_from_http?
In your current implementation, there is a chance the skipped Error comes from parts = urllib.parse.urlparse(url) or re.match(r"(drive|docs)[.]google[.]com", parts.netloc).

Contributor:
I will import and merge this after this change. Thanks!

Contributor Author:
I have implemented this change, but I don't really understand why it is necessary. In my opinion the source of the exception doesn't really matter if we want to skip over them anyway. With this change exceptions caused by trying to parse the url will not be caught.

Contributor:
> With this change exceptions caused by trying to parse the url will not be caught.

I think that is the point. If the URL cannot be parsed, perhaps users want to know and fix it. If you cannot get a response, then they may want to skip it.

(A sketch of this narrower try-except placement appears after the end of this diff.)

        except Exception as e:
            if self.skip_on_error:
                warnings.warn(f"{e}, skipping...")
            else:
                raise

def __len__(self) -> int:
    return len(self.source_datapipe)
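
As a companion to the review discussion above, here is a rough sketch of what the requested change could look like, with the try-except wrapped only around the fetch helpers so that URL-parsing failures still raise. This is an illustration under that assumption, not necessarily the exact code that was merged:

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
    for url in self.source_datapipe:
        # Parsing stays outside the try block: a malformed URL should surface as an error
        # rather than being skipped silently.
        parts = urllib.parse.urlparse(url)

        if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
            fetch = _get_response_from_google_drive
        else:
            fetch = _get_response_from_http

        try:
            yield fetch(url, timeout=self.timeout, **self.query_params)
        except Exception as e:
            if self.skip_on_error:
                warnings.warn(f"{e}, skipping...")
            else:
                raise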