Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pymupdf
import ray
import ray.data
from ray.data.expressions import download
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To make the code even more idiomatic and potentially more performant through expression optimization, consider importing col here. It can be used to replace the lambda in the filter operation below.

Suggested change
from ray.data.expressions import download
from ray.data.expressions import col, download

import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
Expand Down Expand Up @@ -42,15 +43,15 @@ def extract_text_from_pdf(row):
bs = row.pop("bytes")
doc = pymupdf.Document(stream=bs, filetype="pdf")
if len(doc) > MAX_PDF_PAGES:
path = row["path"]
path = row["uploaded_pdf_path"]
print(f"Skipping PDF {path} because it has {len(doc)} pages")
return
for page in doc:
row["page_text"] = page.get_text()
row["page_number"] = page.number
yield row
except Exception as e:
path = row["path"]
path = row["uploaded_pdf_path"]
print(f"Error extracting text from PDF {path}: {e}")
return

Expand Down Expand Up @@ -83,19 +84,22 @@ def __call__(self, batch):

start_time = time.time()

file_paths = (
(
ray.data.read_parquet(INPUT_PATH)
.filter(lambda row: row["file_name"].endswith(".pdf"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a col expression is more idiomatic in Ray Data than using a lambda for simple filtering operations. This declarative style can also allow Ray Data to perform more optimizations on the execution plan.

Suggested change
.filter(lambda row: row["file_name"].endswith(".pdf"))
.filter(col("file_name").str.endswith(".pdf"))

.take_all()
)
file_paths = [row["uploaded_pdf_path"] for row in file_paths]
ds = ray.data.read_binary_files(file_paths, include_paths=True)
ds = ds.flat_map(extract_text_from_pdf)
ds = ds.flat_map(chunker)
ds = ds.map_batches(
Embedder, concurrency=NUM_GPU_NODES, num_gpus=1.0, batch_size=EMBEDDING_BATCH_SIZE
.with_column("bytes", download("uploaded_pdf_path"))
.flat_map(extract_text_from_pdf)
.flat_map(chunker)
.map_batches(
Embedder,
concurrency=NUM_GPU_NODES,
num_gpus=1.0,
batch_size=EMBEDDING_BATCH_SIZE,
)
.select_columns(
["uploaded_pdf_path", "page_number", "chunk_id", "chunk", "embedding"]
)
.write_parquet(OUTPUT_PATH)
)
ds = ds.select_columns(["path", "page_number", "chunk_id", "chunk", "embedding"])
ds.write_parquet(OUTPUT_PATH)

print("Runtime:", time.time() - start_time)