-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Functionality to accept compressed files as input to predict when using a Predictor #5299
Changes from 4 commits
0a38254
316b15e
47349d4
79d7e96
0ac8f67
85bddc7
ef821b9
026e42a
23c50c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ | |
from allennlp.commands.subcommand import Subcommand | ||
from allennlp.common import logging as common_logging | ||
from allennlp.common.checks import check_for_gpu, ConfigurationError | ||
from allennlp.common.file_utils import cached_path | ||
from allennlp.common.file_utils import cached_path, open_compressed | ||
from allennlp.common.util import lazy_groups_of | ||
from allennlp.data.dataset_readers import MultiTaskDatasetReader | ||
from allennlp.models.archival import load_archive | ||
|
@@ -73,6 +73,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument | |
"flag is set.", | ||
) | ||
|
||
subparser.add_argument( | ||
"--compression-type", | ||
type=str, | ||
choices=["gz", "bz2", "lzma"], | ||
default=None, | ||
help="Indicates the compressed format of the input file.", | ||
) | ||
|
||
subparser.add_argument( | ||
"--multitask-head", | ||
type=str, | ||
|
@@ -152,6 +160,7 @@ def __init__( | |
batch_size: int, | ||
print_to_console: bool, | ||
has_dataset_reader: bool, | ||
compression_type: str = None, | ||
multitask_head: Optional[str] = None, | ||
) -> None: | ||
self._predictor = predictor | ||
|
@@ -160,7 +169,7 @@ def __init__( | |
self._batch_size = batch_size | ||
self._print_to_console = print_to_console | ||
self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader | ||
|
||
self.compression_type = compression_type | ||
self._multitask_head = multitask_head | ||
if self._multitask_head is not None: | ||
if self._dataset_reader is None: | ||
|
@@ -212,10 +221,21 @@ def _get_json_data(self) -> Iterator[JsonDict]: | |
yield self._predictor.load_line(line) | ||
else: | ||
input_file = cached_path(self._input_file) | ||
with open(input_file, "r") as file_input: | ||
for line in file_input: | ||
if not line.isspace(): | ||
yield self._predictor.load_line(line) | ||
try: | ||
with open_compressed(input_file) as file_input: | ||
for line in file_input: | ||
if not line.isspace(): | ||
yield self._predictor.load_line(line) | ||
except OSError: | ||
if self.compression_type: | ||
with open_compressed(input_file, self.compression_type) as file_input: | ||
for line in file_input: | ||
if not line.isspace(): | ||
yield self._predictor.load_line(line) | ||
else: | ||
print( | ||
"Automatic detection of compression type failed, please specify the compression type argument" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This logic needs to be the other way around. If the compression type is specified, we have to always respect it. If it's not specified, we autodetect. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That makes sense, will incorporate it. |
||
|
||
def _get_instance_data(self) -> Iterator[Instance]: | ||
if self._input_file == "-": | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1085,20 +1085,27 @@ def get_file_extension(path: str, dot=True, lower: bool = True): | |||||
|
||||||
|
||||||
def open_compressed( | ||||||
filename: Union[str, PathLike], mode: str = "rt", encoding: Optional[str] = "UTF-8", **kwargs | ||||||
filename: Union[str, PathLike], | ||||||
compression_type: str = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be the last parameter, so we don't break existing usage of positional arguments. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, the type of this should be `Optional[str]`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the […] |
||||||
mode: str = "rt", | ||||||
encoding: Optional[str] = "UTF-8", | ||||||
**kwargs, | ||||||
): | ||||||
if not isinstance(filename, str): | ||||||
filename = str(filename) | ||||||
open_fn: Callable = open | ||||||
|
||||||
if filename.endswith(".gz"): | ||||||
import gzip | ||||||
|
||||||
open_fn = gzip.open | ||||||
elif filename.endswith(".bz2"): | ||||||
import bz2 | ||||||
compression_modules = {"gz": "gzip", "bz2": "bz2", "lzma": "lzma"} | ||||||
if not compression_type: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If I pass in an empty string for `compression_type` […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can I simply consider an empty string for `compression_type` to be equivalent to it being `None`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that would be surprising to a user of the library. You can just compare to `None`.
Suggested change
|
||||||
for extension in compression_modules: | ||||||
if filename.endswith(extension): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you use […] |
||||||
module = __import__(compression_modules[extension]) | ||||||
open_fn = module.open | ||||||
break | ||||||
else: | ||||||
module = __import__(compression_modules[extension]) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, that was an error at my end, will fix it. |
||||||
open_fn = module.open | ||||||
|
||||||
open_fn = bz2.open | ||||||
return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs) | ||||||
|
||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type should be `Optional[str]`.