Skip to content

Commit

Permalink
add missing decompression functions
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxDall committed Nov 8, 2024
1 parent bca2308 commit 74dfc8e
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions src/fundus/scraping/url.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import bz2
import gzip
import itertools
import lzma
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from typing import Callable, ClassVar, Dict, Iterable, Iterator, List, Optional

Expand All @@ -21,48 +22,64 @@
logger = create_logger(__name__)


class CompressionFormats(Enum):
GZIP = 1
BZ2 = 2
ZIP = 3
LZMA = 4
class CompressionFormat:
def __init__(
self, name: str, decompression: Optional[Callable[[bytes], bytes]] = None, *, byte_mask: Optional[bytes] = None
) -> None:
self.name = name
self.decompression = decompression
self.byte_mask = byte_mask

def match(self, compressed_content: bytes) -> bool:
if self.byte_mask:
return compressed_content.startswith(self.byte_mask)
return False

def __call__(self, compressed_content: bytes) -> bytes:
if self.decompression is None:
raise NotImplementedError(f"Decompression not implemented for {self.name!r}")
return self.decompression(compressed_content)

def __repr__(self):
if self.decompression is None:
return f"{self.name} -- Not implemented"
return self.name


class CompressionFormats:
GZIP = CompressionFormat("gzip", gzip.decompress, byte_mask=b"\x1f\x8b")
BZ2 = CompressionFormat("bz2", bz2.decompress, byte_mask=b"\x42\x5a")
ZIP = CompressionFormat("zip", byte_mask=b"PK\x03\x04")
LZMA = CompressionFormat("lzma", lzma.decompress, byte_mask=b"\x28\xb5\x2f\xfd")

@classmethod
def iter_formats(cls) -> Iterator[CompressionFormat]:
for obj in cls.__dict__.values():
if isinstance(obj, CompressionFormat):
yield obj

@classmethod
def identify(cls, compressed_content: bytes) -> Optional[CompressionFormat]:
for compression_format in cls.iter_formats():
if compression_format.match(compressed_content):
return compression_format
return None


class _ArchiveDecompressor:
def __init__(self):
self.archive_mapping: Dict[str, Callable[[bytes], bytes]] = {
"application/octet-stream": self._decompress_octet_stream,
"application/x-gzip": self._decompress_gzip,
"gzip": self._decompress_gzip,
"application/x-gzip": CompressionFormats.GZIP,
"gzip": CompressionFormats.GZIP,
}

@staticmethod
def identify_compression_format(compressed_content: bytes) -> Optional[CompressionFormats]:
if compressed_content.startswith(b"\x1f\x8b"):
return CompressionFormats.GZIP
elif compressed_content.startswith(b"\x42\x5a"):
return CompressionFormats.BZ2
elif compressed_content.startswith(b"PK\x03\x04"):
return CompressionFormats.ZIP
elif compressed_content.startswith(b"\x28\xb5\x2f\xfd"):
return CompressionFormats.LZMA
return None

def _decompress_octet_stream(self, compressed_content: bytes) -> bytes:
if (compression_format := self.identify_compression_format(compressed_content)) is None:
if (compression_format := CompressionFormats.identify(compressed_content)) is None:
logger.debug(f"Could not identify compression format")
raise NotImplementedError

if compression_format == CompressionFormats.GZIP:
return self._decompress_gzip(compressed_content)
else:
logger.debug(f"Decompression not implemented for {compression_format.name!r} format")
raise NotImplementedError

@staticmethod
def _decompress_gzip(compressed_content: bytes) -> bytes:
decompressed_content = gzip.decompress(compressed_content)
return decompressed_content
return compression_format(compressed_content)

def decompress(self, content: bytes, file_format: "str") -> bytes:
decompress_function = self.archive_mapping[file_format]
Expand Down

0 comments on commit 74dfc8e

Please sign in to comment.