Skip to content

Commit

Permalink
Store hashes of extracted CoNLL2000Chunking files (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Feb 23, 2021
1 parent 57a178f commit 2764143
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
14 changes: 7 additions & 7 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestUtils(TorchtextTestCase):

def test_download_extract_tar(self):
# create root directory for downloading data
root = '.data'
root = os.path.abspath('.data')
if not os.path.exists(root):
os.makedirs(root)

Expand Down Expand Up @@ -47,7 +47,7 @@ def test_download_extract_tar(self):

def test_download_extract_gz(self):
# create root directory for downloading data
root = '.data'
root = os.path.abspath('.data')
if not os.path.exists(root):
os.makedirs(root)

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_download_extract_gz(self):

def test_download_extract_zip(self):
# create root directory for downloading data
root = '.data'
root = os.path.abspath('.data')
if not os.path.exists(root):
os.makedirs(root)

Expand Down Expand Up @@ -110,18 +110,18 @@ def test_download_extract_zip(self):
def test_no_download(self):
    """Verify download_from_url returns the pre-existing local file instead of downloading.

    A copy of the asset is placed at the expected location under `.data`
    before the call, so the (fake) URL must never be fetched; the function
    should detect the file and return its absolute path.
    """
    asset_name = 'glove.840B.300d.zip'
    asset_path = get_asset_path(asset_name)
    # Absolute paths throughout, so the equality check below is not
    # sensitive to the current working directory.
    root = os.path.abspath('.data')
    if not os.path.exists(root):
        os.makedirs(root)
    data_path = os.path.abspath(os.path.join('.data', asset_name))
    shutil.copy(asset_path, data_path)
    # Fake URL: a real fetch would fail, proving the cached file was used.
    file_path = utils.download_from_url('fakedownload/glove.840B.300d.zip')
    self.assertEqual(file_path, data_path)
    conditional_remove(data_path)

def test_download_extract_to_path(self):
# create root directory for downloading data
root = '.data'
root = os.path.abspath('.data')
if not os.path.exists(root):
os.makedirs(root)

Expand Down Expand Up @@ -157,7 +157,7 @@ def test_download_extract_to_path(self):
@unittest.skip("Download temp. slow.")
def test_extract_non_tar_zip(self):
# create root directory for downloading data
root = '.data'
root = os.path.abspath('.data')
if not os.path.exists(root):
os.makedirs(root)

Expand Down
15 changes: 8 additions & 7 deletions torchtext/data/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,19 @@ def new_fn(fn):

def download_extract_validate(root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5,
                              hash_type="sha256"):
    """Download an archive, extract it, and validate the extracted file's hash.

    If `extracted_file` already exists and its hash matches
    `extracted_file_md5`, the download is skipped entirely and the existing
    path is returned.

    Args:
        root: directory under which the archive is downloaded.
        url: archive URL passed to download_from_url.
        url_md5: expected hash of the downloaded archive.
        downloaded_file: filename (relative to root) for the archive.
        extracted_file: path of the file expected after extraction.
        extracted_file_md5: expected hash of the extracted file.
        hash_type: hash algorithm name used for both validations.

    Returns:
        Absolute path to the validated extracted file.
    """
    # Normalize everything to absolute paths so cached-file detection and
    # the post-extraction match are independent of the working directory.
    root = os.path.abspath(root)
    downloaded_file = os.path.abspath(downloaded_file)
    extracted_file = os.path.abspath(extracted_file)
    if os.path.exists(extracted_file):
        # NOTE: extracted_file is absolute here, so os.path.join(root, ...)
        # simply yields extracted_file again.
        with open(os.path.join(root, extracted_file), 'rb') as f:
            if validate_file(f, extracted_file_md5, hash_type):
                return extracted_file

    dataset_tar = download_from_url(url, path=os.path.join(root, downloaded_file),
                                    hash_value=url_md5, hash_type=hash_type)
    extracted_files = extract_archive(dataset_tar)
    # Sanity check: the archive must actually contain the expected file.
    assert extracted_file == find_match(extracted_file, extracted_files)
    return extracted_file


class RawTextIterableDataset(torch.utils.data.IterableDataset):
Expand Down
24 changes: 19 additions & 5 deletions torchtext/datasets/conll2000chunking.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.datasets_utils import RawTextIterableDataset
from torchtext.data.datasets_utils import wrap_split_argument
from torchtext.data.datasets_utils import add_docstring_header
from torchtext.data.datasets_utils import find_match
from torchtext.data.datasets_utils import download_extract_validate
import os
import logging

URL = {
'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
Expand All @@ -19,6 +20,16 @@
'test': 2012,
}

_EXTRACTED_FILES = {
'train': 'train.txt',
'test': 'test.txt'
}

_EXTRACTED_FILES_MD5 = {
'train': "2e2f24e90e20fcb910ab2251b5ed8cd0",
'test': "56944df34be553b72a2a634e539a0951"
}


def _create_data_from_iob(data_path, separator):
with open(data_path, encoding="utf-8") as input_file:
Expand All @@ -41,8 +52,11 @@ def _create_data_from_iob(data_path, separator):
@add_docstring_header(num_lines=NUM_LINES)
@wrap_split_argument(('train', 'test'))
def CoNLL2000Chunking(root, split):
    """Build the CoNLL2000Chunking dataset for the requested split.

    Downloads (or reuses a hash-validated cached copy of) the split's
    archive, extracts it, and wraps the parsed IOB data in a
    RawTextIterableDataset.
    """
    # Create a dataset specific subfolder to deal with generic download filenames
    root = os.path.join(root, 'conll2000chunking')
    path = os.path.join(root, split + ".txt.gz")
    # Validates both the downloaded archive (MD5[split]) and the extracted
    # text file (_EXTRACTED_FILES_MD5[split]); skips the download when a
    # valid extracted file is already present.
    data_filename = download_extract_validate(root, URL[split], MD5[split], path,
                                              os.path.join(root, _EXTRACTED_FILES[split]),
                                              _EXTRACTED_FILES_MD5[split], hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[split],
                                  _create_data_from_iob(data_filename, " "))
4 changes: 4 additions & 0 deletions torchtext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value=
>>> '.data/validation.tar.gz'
"""
if path is not None:
path = os.path.abspath(path)
root = os.path.abspath(root)

def _check_hash(path):
if hash_value:
logging.info('Validating hash {} matches hash of {}'.format(hash_value, path))
Expand Down

0 comments on commit 2764143

Please sign in to comment.