diff --git a/docprompt/_pdfium.py b/docprompt/_pdfium.py
index d1228d9..f9f5d8e 100644
--- a/docprompt/_pdfium.py
+++ b/docprompt/_pdfium.py
@@ -10,6 +10,7 @@ import multiprocessing as mp
 import concurrent.futures as ft
 
 from PIL import Image
+import tqdm
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +71,33 @@ def _render_parallel_init(
     ProcObjs = (pdf, kwargs, return_mode, post_process_fn)
 
 
+def _render_parallel_multi_doc_init(
+    extra_init,
+    inputs,
+    passwords,
+    may_init_forms,
+    kwargs,
+    return_mode="pil",
+    post_process_fn=None,
+):
+    if extra_init:
+        extra_init()
+
+    logger.info(f"Initializing data for process {os.getpid()}")
+
+    pdfs_map = {}
+
+    for i, (input, password) in enumerate(zip(inputs, passwords)):
+        pdf = pdfium.PdfDocument(input, password=password, autoclose=True)
+        if may_init_forms:
+            pdf.init_forms()
+
+        pdfs_map[i] = pdf
+
+    global ProcObjsMultiDoc
+    ProcObjsMultiDoc = (pdfs_map, kwargs, return_mode, post_process_fn)
+
+
 def _render_job(
     i: int,
     pdf: pdfium.PdfDocument,
@@ -98,9 +126,17 @@ def _render_job(
         return buffer.getvalue()
 
 
-def _render_parallel_job(i):
+def _render_parallel_job(page_indice):
     global ProcObjs
-    return _render_job(i, *ProcObjs)
+    return _render_job(page_indice, *ProcObjs)
+
+
+def _render_parallel_multi_doc_job(pdf_indice, page_indice):
+    global ProcObjsMultiDoc
+
+    pdf = ProcObjsMultiDoc[0][pdf_indice]
+
+    return pdf_indice, page_indice, _render_job(page_indice, pdf, *ProcObjsMultiDoc[1:])
 
 
 def rasterize_page_with_pdfium(
@@ -133,7 +169,7 @@ def rasterize_pdf_with_pdfium(
     **kwargs,
 ) -> List[Union[Image.Image, bytes]]:
     """
-    Rasterizes a page of a PDF document
+    Rasterizes an entire PDF using PDFium and a pool of workers
     """
     with get_pdfium_document(fp, password=password) as pdf:
         total_pages = len(pdf)
@@ -153,3 +189,69 @@ def rasterize_pdf_with_pdfium(
         results = executor.map(_render_parallel_job, range(total_pages), chunksize=1)
 
     return list(results)
+
+
+def rasterize_pdfs_with_pdfium(
+    fps: List[Union[PathLike, Path, bytes]],
+    passwords: Optional[List[str]] = None,
+    *,
+    return_mode: Literal["pil", "bytes"] = "pil",
+    post_process_fn: Optional[Callable[[Image.Image], Image.Image]] = None,
+    **kwargs,
+) -> Dict[int, Dict[int, Union[Image.Image, bytes]]]:
+    """
+    Like 'rasterize_pdf_with_pdfium', but optimized for multiple PDFs by loading all PDFs into each worker's memory space
+    """
+    if passwords and len(passwords) != len(fps):
+        raise ValueError(
+            "If specifying passwords, you must provide one for each PDF. Use None for no password."
+        )
+
+    passwords = passwords or [None] * len(fps)
+
+    ctx = mp.get_context("spawn")
+
+    page_counts = []
+    total_to_process = 0
+
+    for fp, password in zip(fps, passwords):
+        with get_pdfium_document(fp, password) as pdf:
+            page_counts.append(len(pdf))
+            total_to_process += len(pdf)
+
+    initargs = (
+        None,
+        fps,
+        passwords,
+        False,
+        kwargs,
+        return_mode,
+        post_process_fn,
+    )
+
+    results = {}
+
+    futures = []
+
+    max_workers = min(mp.cpu_count(), total_to_process)
+
+    with tqdm.tqdm(total=total_to_process) as pbar:
+        with ft.ProcessPoolExecutor(
+            max_workers=max_workers,
+            initializer=_render_parallel_multi_doc_init,
+            initargs=initargs,
+            mp_context=ctx,
+        ) as executor:
+            for i, page_count in enumerate(page_counts):
+                for j in range(page_count):
+                    futures.append(
+                        executor.submit(_render_parallel_multi_doc_job, i, j)
+                    )
+
+            for future in ft.as_completed(futures):
+                pdf_indice, page_indice, result = future.result()
+
+                results.setdefault(pdf_indice, {})[page_indice + 1] = result
+                pbar.update(1)
+
+    return results
diff --git a/docprompt/schema/pipeline.py b/docprompt/schema/pipeline.py
index ed74a49..7a01f32 100644
--- a/docprompt/schema/pipeline.py
+++ b/docprompt/schema/pipeline.py
@@ -216,6 +216,15 @@ def rasterize(
 
         return list(images.values())
 
+    def propagate_cache(self, name: str, rasters: Dict[int, Union[bytes, Image.Image]]):
+        """
+        Stores pre-rendered rasters on the owning document's page nodes; `rasters` must be keyed by one-indexed page number
+        """
+        for page_number, raster in rasters.items():
+            page_node = self.owner.page_nodes[page_number - 1]
+
+            page_node._raster_cache[name] = raster
+
 
 class PageNode(BaseModel, Generic[PageNodeMetadata]):
     """
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 5b1734a..085656c 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -8,4 +8,10 @@
         file_hash="121ffed4336e6129e97ee3c4cb747864",
         ocr_name="1_ocr.json",
     ),
+    PdfFixture(
+        name="2.pdf",
+        page_count=23,
+        file_hash="bd2fa4f101b305e4001acf9137ce78cf",
+        ocr_name="1_ocr.json",
+    ),
 ]
diff --git a/tests/fixtures/2.pdf b/tests/fixtures/2.pdf
new file mode 100644
index 0000000..371873b
Binary files /dev/null and b/tests/fixtures/2.pdf differ
diff --git a/tests/test_documentnode.py b/tests/test_documentnode.py
index b03ec99..bb284be 100644
--- a/tests/test_documentnode.py
+++ b/tests/test_documentnode.py
@@ -1,5 +1,6 @@
 import pickle
 
 from docprompt import load_document, DocumentNode
+from docprompt._pdfium import rasterize_pdfs_with_pdfium
 from .fixtures import PDF_FIXTURES
 from PIL import Image
@@ -53,6 +54,33 @@ def test_rasterize_via_document_node():
     )
 
 
+def test_multi_rasterize():
+    document_1 = load_document(PDF_FIXTURES[0].get_full_path())
+    document_2 = load_document(PDF_FIXTURES[1].get_full_path())
+
+    node_1 = DocumentNode.from_document(document_1)
+    node_2 = DocumentNode.from_document(document_2)
+
+    results = rasterize_pdfs_with_pdfium([document_1.file_bytes, document_2.file_bytes])
+
+    node_1.rasterizer.propagate_cache("default", results[0])
+    node_2.rasterizer.propagate_cache("default", results[1])
+
+    assert all(
+        (
+            "default" in page_node.rasterizer.raster_cache
+            for page_node in node_1.page_nodes
+        )
+    )
+
+    assert all(
+        (
+            "default" in page_node.rasterizer.raster_cache
+            for page_node in node_2.page_nodes
+        )
+    )
+
+
 def test__pickling_drops_cache():
     document = load_document(PDF_FIXTURES[0].get_full_path())
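
Usage sketch (illustrative only, assembled from the added test_multi_rasterize test; the file paths and the "default" cache key are assumptions for the example, not a documented API):

    from docprompt import load_document, DocumentNode
    from docprompt._pdfium import rasterize_pdfs_with_pdfium

    # Hypothetical input files
    documents = [load_document("a.pdf"), load_document("b.pdf")]
    nodes = [DocumentNode.from_document(doc) for doc in documents]

    # A single worker pool renders every page of every PDF; the result maps
    # pdf index -> {one-indexed page number -> rendered page ("pil" mode by default)}
    results = rasterize_pdfs_with_pdfium([doc.file_bytes for doc in documents])

    # Push the rendered pages into each node's raster cache under the "default" key
    for i, node in enumerate(nodes):
        node.rasterizer.propagate_cache("default", results[i])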