Skip to content

Commit

Permalink
Add fast concurrent raster
Browse files Browse the repository at this point in the history
  • Loading branch information
PSU3D0 committed May 16, 2024
1 parent ebab120 commit 52cbe84
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 3 deletions.
108 changes: 105 additions & 3 deletions docprompt/_pdfium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import multiprocessing as mp
import concurrent.futures as ft
from PIL import Image
import tqdm

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,6 +71,33 @@ def _render_parallel_init(
ProcObjs = (pdf, kwargs, return_mode, post_process_fn)


def _render_parallel_multi_doc_init(
    extra_init,
    inputs,
    passwords,
    may_init_forms,
    kwargs,
    return_mode="pil",
    post_process_fn=None,
):
    """
    Per-process initializer that opens every PDF once and stores the result globally.

    Opens each entry of `inputs` with the password at the same position in
    `passwords`, keys the opened documents by their position, and stores them
    together with the render settings in the module-level `ProcObjsMultiDoc`
    tuple, which `_render_parallel_multi_doc_job` reads.

    Args:
        extra_init: Optional zero-argument callable run before anything else.
        inputs: PDF sources handed to `pdfium.PdfDocument`.
        passwords: One password (or None) per entry of `inputs`.
        may_init_forms: When truthy, `init_forms()` is called on each document.
        kwargs: Render keyword arguments stored for later render jobs.
        return_mode: Stored as-is for later render jobs.
        post_process_fn: Stored as-is for later render jobs.
    """
    global ProcObjsMultiDoc

    if extra_init:
        extra_init()

    logger.info(f"Initializing data for process {os.getpid()}")

    documents = {}

    for doc_index, (source, secret) in enumerate(zip(inputs, passwords)):
        document = pdfium.PdfDocument(source, password=secret, autoclose=True)

        if may_init_forms:
            document.init_forms()

        documents[doc_index] = document

    # Published only after every document opened successfully.
    ProcObjsMultiDoc = (documents, kwargs, return_mode, post_process_fn)


def _render_job(
i: int,
pdf: pdfium.PdfDocument,
Expand Down Expand Up @@ -98,9 +126,17 @@ def _render_job(
return buffer.getvalue()


def _render_parallel_job(i):
def _render_parallel_job(page_indice):
global ProcObjs
return _render_job(i, *ProcObjs)
return _render_job(page_indice, *ProcObjs)


def _render_parallel_multi_doc_job(pdf_indice, page_indice):
    """
    Render one page of one of the documents held in `ProcObjsMultiDoc`.

    Args:
        pdf_indice: Key of the document in the map stored by
            `_render_parallel_multi_doc_init`.
        page_indice: Page index forwarded to `_render_job`.

    Returns:
        A `(pdf_indice, page_indice, rendered)` tuple so the caller can match
        the result to its document and page regardless of completion order.
    """
    global ProcObjsMultiDoc

    # First element is the document map; the rest are passed through to _render_job.
    documents, *render_args = ProcObjsMultiDoc
    document = documents[pdf_indice]

    rendered = _render_job(page_indice, document, *render_args)

    return pdf_indice, page_indice, rendered


def rasterize_page_with_pdfium(
Expand Down Expand Up @@ -133,7 +169,7 @@ def rasterize_pdf_with_pdfium(
**kwargs,
) -> List[Union[Image.Image, bytes]]:
"""
Rasterizes a page of a PDF document
Rasterizes an entire PDF using PDFium and a pool of workers
"""
with get_pdfium_document(fp, password=password) as pdf:
total_pages = len(pdf)
Expand All @@ -153,3 +189,69 @@ def rasterize_pdf_with_pdfium(
results = executor.map(_render_parallel_job, range(total_pages), chunksize=1)

return list(results)


def rasterize_pdfs_with_pdfium(
    fps: List[Union[PathLike, Path, bytes]],
    passwords: Optional[List[str]] = None,
    *,
    return_mode: Literal["pil", "bytes"] = "pil",
    post_process_fn: Optional[Callable[[Image.Image], Image.Image]] = None,
    **kwargs,
) -> Dict[int, Dict[int, Union[Image.Image, bytes]]]:
    """
    Like 'rasterize_pdf_with_pdfium', but optimized for multiple PDFs by
    loading all PDFs into each worker's memory space.

    Every worker process opens all of the given PDFs once (in its initializer),
    then individual (document, page) render jobs are distributed over the pool.

    Args:
        fps: The PDFs to rasterize, as paths or raw bytes.
        passwords: Optional list with one password (or None) per PDF.
        return_mode: Forwarded to the render jobs ("pil" or "bytes").
        post_process_fn: Optional callable forwarded to the render jobs.
        **kwargs: Extra render arguments forwarded to the render jobs.

    Returns:
        A dict keyed by the zero-based position of each PDF in `fps`. Each
        value is a dict keyed by ONE-based page number holding the rendered
        page. A PDF without pages maps to an empty dict.

    Raises:
        ValueError: If `passwords` is given but its length differs from `fps`.
    """
    if passwords and len(passwords) != len(fps):
        raise ValueError(
            "If specifying passwords, must provide one for each PDF. Use None for no password."
        )

    passwords = passwords or [None] * len(fps)

    # Pre-seed one entry per input so that documents without pages still show
    # up in the result instead of being silently absent.
    results = {pdf_index: {} for pdf_index in range(len(fps))}

    page_counts = []

    for fp, password in zip(fps, passwords):
        with get_pdfium_document(fp, password) as pdf:
            page_counts.append(len(pdf))

    total_to_process = sum(page_counts)

    # ProcessPoolExecutor rejects max_workers=0, so bail out before building a
    # pool when there is nothing to render.
    if total_to_process == 0:
        return results

    ctx = mp.get_context("spawn")

    initargs = (
        None,
        fps,
        passwords,
        False,
        kwargs,
        return_mode,
        post_process_fn,
    )

    futures = []

    # Never start more workers than there are pages to render.
    max_workers = min(mp.cpu_count(), total_to_process)

    with tqdm.tqdm(total=total_to_process) as pbar:
        with ft.ProcessPoolExecutor(
            max_workers=max_workers,
            initializer=_render_parallel_multi_doc_init,
            initargs=initargs,
            mp_context=ctx,
        ) as executor:
            for pdf_index, page_count in enumerate(page_counts):
                for page_index in range(page_count):
                    futures.append(
                        executor.submit(
                            _render_parallel_multi_doc_job, pdf_index, page_index
                        )
                    )

            for future in ft.as_completed(futures):
                pdf_indice, page_indice, result = future.result()

                # Workers report zero-based page indices; results are one-indexed.
                results[pdf_indice][page_indice + 1] = result
                pbar.update(1)

    return results
9 changes: 9 additions & 0 deletions docprompt/schema/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ def rasterize(

return list(images.values())

def propagate_cache(self, name: str, rasters: Dict[int, Union[bytes, Image.Image]]):
    """
    Store pre-computed rasters in the raster cache of the owner's page nodes.

    Args:
        name: Cache key under which each raster is stored.
        rasters: Mapping of ONE-indexed page number to the raster for that page.

    Raises:
        IndexError: If a page number is below 1 or beyond the last page.
    """
    page_nodes = self.owner.page_nodes

    for page_number, raster in rasters.items():
        # Without this guard, page 0 (or a negative number) would become a
        # negative list index and silently land on a page at the end.
        if page_number < 1:
            raise IndexError(
                f"Page numbers are one-indexed, got {page_number}"
            )

        page_nodes[page_number - 1]._raster_cache[name] = raster


class PageNode(BaseModel, Generic[PageNodeMetadata]):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,10 @@
file_hash="121ffed4336e6129e97ee3c4cb747864",
ocr_name="1_ocr.json",
),
PdfFixture(
name="2.pdf",
page_count=23,
file_hash="bd2fa4f101b305e4001acf9137ce78cf",
ocr_name="1_ocr.json",
),
]
Binary file added tests/fixtures/2.pdf
Binary file not shown.
28 changes: 28 additions & 0 deletions tests/test_documentnode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
from docprompt import load_document, DocumentNode
from docprompt._pdfium import rasterize_pdfs_with_pdfium
from .fixtures import PDF_FIXTURES

from PIL import Image
Expand Down Expand Up @@ -53,6 +54,33 @@ def test_rasterize_via_document_node():
)


def test_multi_rasterize():
    """
    Rasterize two fixture PDFs in one call and check that propagating the
    per-document results fills the "default" raster cache of every page node.
    """
    first_document = load_document(PDF_FIXTURES[0].get_full_path())
    second_document = load_document(PDF_FIXTURES[1].get_full_path())

    first_node = DocumentNode.from_document(first_document)
    second_node = DocumentNode.from_document(second_document)

    rasters_by_document = rasterize_pdfs_with_pdfium(
        [first_document.file_bytes, second_document.file_bytes]
    )

    # Results are keyed by the position of each PDF in the input list.
    first_node.rasterizer.propagate_cache("default", rasters_by_document[0])
    second_node.rasterizer.propagate_cache("default", rasters_by_document[1])

    for node in (first_node, second_node):
        for page_node in node.page_nodes:
            assert "default" in page_node.rasterizer.raster_cache


def test__pickling_drops_cache():
document = load_document(PDF_FIXTURES[0].get_full_path())

Expand Down

0 comments on commit 52cbe84

Please sign in to comment.