Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Functionality to accept compressed files as input to predict when using a Predictor #5299

Closed
wants to merge 9 commits into from
36 changes: 30 additions & 6 deletions allennlp/commands/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common import logging as common_logging
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.file_utils import cached_path, open_compressed
from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset_readers import MultiTaskDatasetReader
from allennlp.models.archival import load_archive
Expand Down Expand Up @@ -71,6 +71,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
"flag is set.",
)

subparser.add_argument(
"--compression-type",
type=str,
choices=["gz", "bz2", "lzma"],
default=None,
help="Indicates the compressed format of the input file.",
)

subparser.add_argument(
"--multitask-head",
type=str,
Expand Down Expand Up @@ -150,6 +158,7 @@ def __init__(
batch_size: int,
print_to_console: bool,
has_dataset_reader: bool,
compression_type: str = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type should be `Optional[str]`.

multitask_head: Optional[str] = None,
) -> None:
self._predictor = predictor
Expand All @@ -158,7 +167,7 @@ def __init__(
self._batch_size = batch_size
self._print_to_console = print_to_console
self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader

self.compression_type = compression_type
self._multitask_head = multitask_head
if self._multitask_head is not None:
if self._dataset_reader is None:
Expand Down Expand Up @@ -210,10 +219,25 @@ def _get_json_data(self) -> Iterator[JsonDict]:
yield self._predictor.load_line(line)
else:
input_file = cached_path(self._input_file)
with open(input_file, "r") as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
if self.compression_type is None:
try:
with open_compressed(input_file) as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
except OSError:
print(
"Automatic detection failed, please specify the compression type argument."
)

else:
try:
with open_compressed(input_file, compression_type=self.compression_type) as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
except OSError:
print("please specify the correct compression type argument.")
Comment on lines +239 to +240
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why `OSError`? I don't think you have to catch this exception at all. If it fails, it fails. The only thing we might want to do is make sure that `open_compressed()` throws exceptions that are understandable.


def _get_instance_data(self) -> Iterator[Instance]:
if self._input_file == "-":
Expand Down
23 changes: 15 additions & 8 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,20 +444,27 @@ def get_file_extension(path: str, dot=True, lower: bool = True):


def open_compressed(
filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs
filename: Union[str, PathLike],
mode: str = "rt",
encoding: Optional[str] = "UTF-8",
compression_type: Optional[str] = None,
**kwargs,
):
if not isinstance(filename, str):
filename = str(filename)
open_fn: Callable = open

if filename.endswith(".gz"):
import gzip

open_fn = gzip.open
elif filename.endswith(".bz2"):
import bz2
compression_modules = {"gz": "gzip", "bz2": "bz2", "lzma": "lzma"}
if compression_type in compression_modules:
module = __import__(compression_modules[compression_type])
open_fn = module.open
else:
for extension in compression_modules:
if filename.endswith(extension):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use `os.path.splitext()` here to make that detection? I don't want a file named `info.fogbugz` to show up as a gzip file.

module = __import__(compression_modules[extension])
open_fn = module.open
break

open_fn = bz2.open
return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs)


Expand Down
2 changes: 1 addition & 1 deletion tests/common/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_open_compressed(self):
with open_compressed(uncompressed_file) as f:
uncompressed_lines = [line.strip() for line in f]

for suffix in ["bz2", "gz"]:
for suffix in ["bz2", "gz", "lzma"]:
compressed_file = f"{uncompressed_file}.{suffix}"
with open_compressed(compressed_file) as f:
compressed_lines = [line.strip() for line in f]
Expand Down