Skip to content

Commit

Permalink
feat: Add advanced PDF parsing option for RAG file import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663391146
  • Loading branch information
speedstorm1 authored and copybara-github committed Aug 15, 2024
1 parent d03468a commit 6e1dc06
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
RagFileParsingConfig,
ImportRagFilesConfig,
ImportRagFilesRequest,
ImportRagFilesResponse,
Expand Down Expand Up @@ -93,6 +94,7 @@
# GCS
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
Expand All @@ -112,18 +114,36 @@
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
# Expected Drive-folder import config with advanced PDF parsing explicitly off.
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
False
)
# Variant of the Drive-folder import config with advanced PDF parsing turned on:
# a single folder-type resource for TEST_DRIVE_FOLDER_ID plus the parsing flag.
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
GoogleDriveSource.ResourceId(
resource_id=TEST_DRIVE_FOLDER_ID,
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
)
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
True
)
# Expected import request wrapping the config with advanced PDF parsing off.
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
)
# Expected import request wrapping the config with advanced PDF parsing on.
TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING = ImportRagFilesRequest(
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING,
)
# Google Drive files
TEST_DRIVE_FILE_ID = "456"
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
# Expected import config for a single Drive file: the test chunking settings
# plus advanced PDF parsing explicitly disabled. The chunking config must be
# closed with "),", so that the parsing config is passed as a second keyword
# argument of ImportRagFilesConfig rather than falling outside the call.
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    ),
    rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
)
# Set after construction; this config uses 800 requests per minute.
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def import_files_request_eq(returned_request, expected_request):
returned_request.import_rag_files_config.jira_source.jira_queries
== expected_request.import_rag_files_config.jira_source.jira_queries
)
assert (
returned_request.import_rag_files_config.rag_file_parsing_config
== expected_request.import_rag_files_config.rag_file_parsing_config
)


@pytest.mark.usefixtures("google_auth_mock")
Expand Down Expand Up @@ -396,6 +400,17 @@ def test_prepare_import_files_request_drive_folders(self, path):
)
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FOLDER)

@pytest.mark.parametrize("path", [tc.TEST_DRIVE_FOLDER, tc.TEST_DRIVE_FOLDER_2])
def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path):
    """Check the request built for a Drive folder with advanced PDF parsing on.

    Runs once per parametrized folder path and compares the prepared request
    against the expected advanced-parsing Drive-folder request constant.
    """
    request_kwargs = {
        "corpus_name": tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        "paths": [path],
        "chunk_size": tc.TEST_CHUNK_SIZE,
        "chunk_overlap": tc.TEST_CHUNK_OVERLAP,
        "use_advanced_pdf_parsing": True,
    }
    built_request = prepare_import_files_request(**request_kwargs)
    import_files_request_eq(
        built_request, tc.TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING
    )

def test_prepare_import_files_request_drive_files(self):
paths = [tc.TEST_DRIVE_FILE]
request = prepare_import_files_request(
Expand Down
8 changes: 8 additions & 0 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def import_files(
chunk_overlap: int = 200,
timeout: int = 600,
max_embedding_requests_per_min: int = 1000,
use_advanced_pdf_parsing: Optional[bool] = False,
) -> ImportRagFilesResponse:
"""
Import files to an existing RagCorpus, wait until completion.
Expand Down Expand Up @@ -364,6 +365,8 @@ def import_files(
here. If unspecified, a default value of 1,000
QPM would be used.
timeout: Default is 600 seconds.
use_advanced_pdf_parsing: Whether to use advanced PDF
parsing on uploaded files.
Returns:
ImportRagFilesResponse.
"""
Expand All @@ -379,6 +382,7 @@ def import_files(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
max_embedding_requests_per_min=max_embedding_requests_per_min,
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
)
client = _gapic_utils.create_rag_data_service_client()
try:
Expand All @@ -396,6 +400,7 @@ async def import_files_async(
chunk_size: int = 1024,
chunk_overlap: int = 200,
max_embedding_requests_per_min: int = 1000,
use_advanced_pdf_parsing: Optional[bool] = False,
) -> operation_async.AsyncOperation:
"""
Import files to an existing RagCorpus asynchronously.
Expand Down Expand Up @@ -479,6 +484,8 @@ async def import_files_async(
page on the project to set an appropriate value
here. If unspecified, a default value of 1,000
QPM would be used.
use_advanced_pdf_parsing: Whether to use advanced PDF
parsing on uploaded files.
Returns:
operation_async.AsyncOperation.
"""
Expand All @@ -494,6 +501,7 @@ async def import_files_async(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
max_embedding_requests_per_min=max_embedding_requests_per_min,
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
)
async_client = _gapic_utils.create_rag_data_service_async_client()
try:
Expand Down
6 changes: 6 additions & 0 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ImportRagFilesConfig,
ImportRagFilesRequest,
RagFileChunkingConfig,
RagFileParsingConfig,
RagCorpus as GapicRagCorpus,
RagFile as GapicRagFile,
SlackSource as GapicSlackSource,
Expand Down Expand Up @@ -217,19 +218,24 @@ def prepare_import_files_request(
chunk_size: int = 1024,
chunk_overlap: int = 200,
max_embedding_requests_per_min: int = 1000,
use_advanced_pdf_parsing: bool = False,
) -> ImportRagFilesRequest:
if len(corpus_name.split("/")) != 6:
raise ValueError(
"corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
)

rag_file_parsing_config = RagFileParsingConfig(
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
)
rag_file_chunking_config = RagFileChunkingConfig(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
import_rag_files_config = ImportRagFilesConfig(
rag_file_chunking_config=rag_file_chunking_config,
max_embedding_requests_per_min=max_embedding_requests_per_min,
rag_file_parsing_config=rag_file_parsing_config,
)

if source is not None:
Expand Down

0 comments on commit 6e1dc06

Please sign in to comment.