diff --git a/release/nightly_tests/multimodal_inference_benchmarks/document_embedding/ray_data_main.py b/release/nightly_tests/multimodal_inference_benchmarks/document_embedding/ray_data_main.py index 03c0016bf852..e761f1b98516 100644 --- a/release/nightly_tests/multimodal_inference_benchmarks/document_embedding/ray_data_main.py +++ b/release/nightly_tests/multimodal_inference_benchmarks/document_embedding/ray_data_main.py @@ -3,6 +3,7 @@ import pymupdf import ray import ray.data +from ray.data.expressions import download import torch from langchain.text_splitter import RecursiveCharacterTextSplitter from sentence_transformers import SentenceTransformer @@ -42,7 +43,7 @@ def extract_text_from_pdf(row): bs = row.pop("bytes") doc = pymupdf.Document(stream=bs, filetype="pdf") if len(doc) > MAX_PDF_PAGES: - path = row["path"] + path = row["uploaded_pdf_path"] print(f"Skipping PDF {path} because it has {len(doc)} pages") return for page in doc: @@ -50,7 +51,7 @@ def extract_text_from_pdf(row): row["page_number"] = page.number yield row except Exception as e: - path = row["path"] + path = row["uploaded_pdf_path"] print(f"Error extracting text from PDF {path}: {e}") return @@ -83,19 +84,22 @@ def __call__(self, batch): start_time = time.time() -file_paths = ( +( ray.data.read_parquet(INPUT_PATH) .filter(lambda row: row["file_name"].endswith(".pdf")) - .take_all() -) -file_paths = [row["uploaded_pdf_path"] for row in file_paths] -ds = ray.data.read_binary_files(file_paths, include_paths=True) -ds = ds.flat_map(extract_text_from_pdf) -ds = ds.flat_map(chunker) -ds = ds.map_batches( - Embedder, concurrency=NUM_GPU_NODES, num_gpus=1.0, batch_size=EMBEDDING_BATCH_SIZE + .with_column("bytes", download("uploaded_pdf_path")) + .flat_map(extract_text_from_pdf) + .flat_map(chunker) + .map_batches( + Embedder, + concurrency=NUM_GPU_NODES, + num_gpus=1.0, + batch_size=EMBEDDING_BATCH_SIZE, + ) + .select_columns( + ["uploaded_pdf_path", "page_number", "chunk_id", "chunk", "embedding"] + ) + .write_parquet(OUTPUT_PATH) ) -ds = ds.select_columns(["path", "page_number", "chunk_id", "chunk", "embedding"]) -ds.write_parquet(OUTPUT_PATH) print("Runtime:", time.time() - start_time)