Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Handle some fbcode infra nuances. #3994

Merged
merged 1 commit into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions parlai/core/build_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tqdm
import gzip
import math
import contextlib
import parlai.utils.logging as logging
from parlai.utils.io import PathManager

Expand All @@ -30,6 +31,17 @@
from multiprocessing import Pool


try:
    # internal infra requires special attention to use http sessions
    from parlai_fb import get_http_session
except (ImportError, AttributeError):
    # Importing the parlai_fb helper failed, so define a fallback with the
    # same name. Callers can then always write
    # `with get_http_session() as session:`.

    @contextlib.contextmanager
    def get_http_session():
        """
        Yield a plain ``requests.Session`` for use in a ``with`` statement.

        The session is created via ``with requests.Session()``, so leaving the
        caller's ``with`` block also leaves the session's own context.
        """
        with requests.Session() as session:
            yield session


class DownloadableFile:
"""
A class used to abstract any file that has to be downloaded online.
Expand Down Expand Up @@ -95,17 +107,20 @@ def check_header(self):
"""
Performs a HEAD request to check if the URL / Google Drive ID is live.
"""
session = requests.Session()
if self.from_google:
URL = 'https://docs.google.com/uc?export=download'
response = session.head(URL, params={'id': self.url}, stream=True)
else:
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/77.0.3865.90 Safari/537.36'
}
response = session.head(self.url, allow_redirects=True, headers=headers)
status = response.status_code
session.close()
with get_http_session() as session:
if self.from_google:
URL = 'https://docs.google.com/uc?export=download'
response = session.head(URL, params={'id': self.url}, stream=True)
else:
headers = {
'User-Agent': (
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) '
'AppleWebKit/537.36 (KHTML, like Gecko) '
'Chrome/77.0.3865.90 Safari/537.36'
)
}
response = session.head(self.url, allow_redirects=True, headers=headers)
status = response.status_code

assert status == 200

Expand Down Expand Up @@ -166,7 +181,7 @@ def download(url, path, fname, redownload=False, num_retries=5):
while download and retry > 0:
response = None

with requests.Session() as session:
with get_http_session() as session:
try:
response = session.get(url, stream=True, timeout=5)

Expand Down Expand Up @@ -389,7 +404,7 @@ def download_from_google_drive(gd_id, destination):
"""
URL = 'https://docs.google.com/uc?export=download'

with requests.Session() as session:
with get_http_session() as session:
response = session.get(URL, params={'id': gd_id}, stream=True)
token = _get_confirm_token(response)

Expand Down
5 changes: 2 additions & 3 deletions parlai/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
# register any internal file handlers
import parlai_fb # noqa: F401

parlai_fb.finalize_registration(PathManager)
# internal file handlers can't handle atomic saving. see T71772714
USE_ATOMIC_TORCH_SAVE = False
USE_ATOMIC_TORCH_SAVE = not parlai_fb.finalize_registration(PathManager)
except ModuleNotFoundError:
pass
USE_ATOMIC_TORCH_SAVE = True