diff --git a/parlai/core/build_data.py b/parlai/core/build_data.py index 9a707905154..6a9eda3519d 100644 --- a/parlai/core/build_data.py +++ b/parlai/core/build_data.py @@ -21,6 +21,7 @@ import tqdm import gzip import math +import contextlib import parlai.utils.logging as logging from parlai.utils.io import PathManager @@ -30,6 +31,17 @@ from multiprocessing import Pool +try: + # internal infra requires special attention to use http sessions + from parlai_fb import get_http_session +except (ImportError, AttributeError): + + @contextlib.contextmanager + def get_http_session(): + with requests.Session() as session: + yield session + + class DownloadableFile: """ A class used to abstract any file that has to be downloaded online. @@ -95,17 +107,20 @@ def check_header(self): """ Performs a HEAD request to check if the URL / Google Drive ID is live. """ - session = requests.Session() - if self.from_google: - URL = 'https://docs.google.com/uc?export=download' - response = session.head(URL, params={'id': self.url}, stream=True) - else: - headers = { - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.90 Safari/537.36' - } - response = session.head(self.url, allow_redirects=True, headers=headers) - status = response.status_code - session.close() + with get_http_session() as session: + if self.from_google: + URL = 'https://docs.google.com/uc?export=download' + response = session.head(URL, params={'id': self.url}, stream=True) + else: + headers = { + 'User-Agent': ( + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) ' + 'AppleWebKit/537.36 (KHTML, like Gecko) ' + 'Chrome/77.0.3865.90 Safari/537.36' + ) + } + response = session.head(self.url, allow_redirects=True, headers=headers) + status = response.status_code assert status == 200 @@ -166,7 +181,7 @@ def download(url, path, fname, redownload=False, num_retries=5): while download and retry > 0: response = None - with requests.Session() as session: + with get_http_session() as session: try: response = session.get(url, stream=True, timeout=5) @@ -389,7 +404,7 @@ def download_from_google_drive(gd_id, destination): """ URL = 'https://docs.google.com/uc?export=download' - with requests.Session() as session: + with get_http_session() as session: response = session.get(URL, params={'id': gd_id}, stream=True) token = _get_confirm_token(response) diff --git a/parlai/utils/io.py b/parlai/utils/io.py index 3e20c4de431..e1ee44cfebe 100644 --- a/parlai/utils/io.py +++ b/parlai/utils/io.py @@ -23,8 +23,7 @@ # register any internal file handlers import parlai_fb # noqa: F401 - parlai_fb.finalize_registration(PathManager) # internal file handlers can't handle atomic saving. see T71772714 - USE_ATOMIC_TORCH_SAVE = False + USE_ATOMIC_TORCH_SAVE = not parlai_fb.finalize_registration(PathManager) except ModuleNotFoundError: - pass + USE_ATOMIC_TORCH_SAVE = True