Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Functionality to accept compressed files as input to predict when using a Predictor #5299

Closed
wants to merge 9 commits into from
32 changes: 26 additions & 6 deletions allennlp/commands/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common import logging as common_logging
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.file_utils import cached_path, open_compressed
from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset_readers import MultiTaskDatasetReader
from allennlp.models.archival import load_archive
Expand Down Expand Up @@ -73,6 +73,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
"flag is set.",
)

subparser.add_argument(
"--compression-type",
type=str,
choices=["gz", "bz2", "lzma"],
default=None,
help="Indicates the compressed format of the input file.",
)

subparser.add_argument(
"--multitask-head",
type=str,
Expand Down Expand Up @@ -152,6 +160,7 @@ def __init__(
batch_size: int,
print_to_console: bool,
has_dataset_reader: bool,
compression_type: str = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type should be Optional[str].

multitask_head: Optional[str] = None,
) -> None:
self._predictor = predictor
Expand All @@ -160,7 +169,7 @@ def __init__(
self._batch_size = batch_size
self._print_to_console = print_to_console
self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader

self.compression_type = compression_type
self._multitask_head = multitask_head
if self._multitask_head is not None:
if self._dataset_reader is None:
Expand Down Expand Up @@ -212,10 +221,21 @@ def _get_json_data(self) -> Iterator[JsonDict]:
yield self._predictor.load_line(line)
else:
input_file = cached_path(self._input_file)
with open(input_file, "r") as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
try:
with open_compressed(input_file) as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
except OSError:
if self.compression_type:
with open_compressed(input_file, self.compression_type) as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
else:
print(
"Automatic detection of compression type failed, please specify the compression type argument"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic needs to be the other way around. If the compression type is specified, we have to always respect it. If it's not specified, we autodetect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, will incorporate it


def _get_instance_data(self) -> Iterator[Instance]:
if self._input_file == "-":
Expand Down
23 changes: 15 additions & 8 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,20 +1085,27 @@ def get_file_extension(path: str, dot=True, lower: bool = True):


def open_compressed(
filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs
filename: Union[str, PathLike],
compression_type: str = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the last parameter, so we don't break existing usage of positional arguments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the type of this should be Optional[None].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the make typecheck commands gave me incompatibility error on changing it to Optional[None] so I kept it as Optional[str], will that be fine?

mode: str = "rt",
encoding: Optional[str] = "UTF-8",
**kwargs,
):
if not isinstance(filename, str):
filename = str(filename)
open_fn: Callable = open

if filename.endswith(".gz"):
import gzip

open_fn = gzip.open
elif filename.endswith(".bz2"):
import bz2
compression_modules = {"gz": "gzip", "bz2": "bz2", "lzma": "lzma"}
if not compression_type:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I pass in an empty string for compression_type, we will go down this path. I don't think that's what we want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I simply consider an empty string for compression_type to be equivalent to it being 'None' ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would be surprising to a user of the library. You can just compare to None:

Suggested change
if not compression_type:
if compression_type is None:

for extension in compression_modules:
if filename.endswith(extension):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use os.path.splitext() here to make that detection? I don't want a file named info.fogbugz to show up as a gzip file.

module = __import__(compression_modules[extension])
open_fn = module.open
break
else:
module = __import__(compression_modules[extension])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think extension is undefined here? Or am I blind and can't see it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that was an error at my end, will fix it

open_fn = module.open

open_fn = bz2.open
return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs)


Expand Down
2 changes: 1 addition & 1 deletion tests/common/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def test_open_compressed(self):
with open_compressed(uncompressed_file) as f:
uncompressed_lines = [line.strip() for line in f]

for suffix in ["bz2", "gz"]:
for suffix in ["bz2", "gz", "lzma"]:
compressed_file = f"{uncompressed_file}.{suffix}"
with open_compressed(compressed_file) as f:
compressed_lines = [line.strip() for line in f]
Expand Down