diff --git a/allennlp/commands/predict.py b/allennlp/commands/predict.py index 802eb6ba2e6..97fd788f46c 100644 --- a/allennlp/commands/predict.py +++ b/allennlp/commands/predict.py @@ -12,7 +12,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common import logging as common_logging from allennlp.common.checks import check_for_gpu, ConfigurationError -from allennlp.common.file_utils import cached_path +from allennlp.common.file_utils import cached_path, open_compressed from allennlp.common.util import lazy_groups_of from allennlp.data.dataset_readers import MultiTaskDatasetReader from allennlp.models.archival import load_archive @@ -71,6 +71,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument "flag is set.", ) + subparser.add_argument( + "--compression-type", + type=str, + choices=["gz", "bz2", "lzma"], + default=None, + help="Indicates the compressed format of the input file.", + ) + subparser.add_argument( "--multitask-head", type=str, @@ -150,6 +158,7 @@ def __init__( batch_size: int, print_to_console: bool, has_dataset_reader: bool, + compression_type: str = None, multitask_head: Optional[str] = None, ) -> None: self._predictor = predictor @@ -158,7 +167,7 @@ def __init__( self._batch_size = batch_size self._print_to_console = print_to_console self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader - + self.compression_type = compression_type self._multitask_head = multitask_head if self._multitask_head is not None: if self._dataset_reader is None: @@ -210,10 +219,25 @@ def _get_json_data(self) -> Iterator[JsonDict]: yield self._predictor.load_line(line) else: input_file = cached_path(self._input_file) - with open(input_file, "r") as file_input: - for line in file_input: - if not line.isspace(): - yield self._predictor.load_line(line) + if self.compression_type is None: + try: + with open_compressed(input_file) as file_input: + for line in file_input: + if not line.isspace(): + yield self._predictor.load_line(line) + except OSError: + print( + "Automatic detection failed, please specify the compression type argument." + ) + + else: + try: + with open_compressed(input_file, compression_type=self.compression_type) as file_input: + for line in file_input: + if not line.isspace(): + yield self._predictor.load_line(line) + except OSError: + print("please specify the correct compression type argument.") def _get_instance_data(self) -> Iterator[Instance]: if self._input_file == "-": diff --git a/allennlp/common/file_utils.py b/allennlp/common/file_utils.py index 319ee47d107..052ad24c62b 100644 --- a/allennlp/common/file_utils.py +++ b/allennlp/common/file_utils.py @@ -444,20 +444,27 @@ def get_file_extension(path: str, dot=True, lower: bool = True): def open_compressed( - filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs + filename: Union[str, PathLike], + mode: str = "rt", + encoding: Optional[str] = "UTF-8", + compression_type: Optional[str] = None, + **kwargs, ): if not isinstance(filename, str): filename = str(filename) open_fn: Callable = open - if filename.endswith(".gz"): - import gzip - - open_fn = gzip.open - elif filename.endswith(".bz2"): - import bz2 + compression_modules = {"gz": "gzip", "bz2": "bz2", "lzma": "lzma"} + if compression_type in compression_modules: + module = __import__(compression_modules[compression_type]) + open_fn = module.open + else: + for extension in compression_modules: + if filename.endswith(extension): + module = __import__(compression_modules[extension]) + open_fn = module.open + break - open_fn = bz2.open return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs) diff --git a/tests/common/file_utils_test.py b/tests/common/file_utils_test.py index 584c338bd03..de8a9c4f75d 100644 --- a/tests/common/file_utils_test.py +++ b/tests/common/file_utils_test.py @@ -226,7 +226,7 @@ def test_open_compressed(self): with open_compressed(uncompressed_file) as f: uncompressed_lines = [line.strip() for line in f] - for suffix in ["bz2", "gz"]: + for suffix in ["bz2", "gz", "lzma"]: compressed_file = f"{uncompressed_file}.{suffix}" with open_compressed(compressed_file) as f: compressed_lines = [line.strip() for line in f]