Skip to content

Commit

Permalink
feat: Adding Slack and Jira data connector for RAG to SDK
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 658175136
  • Loading branch information
speedstorm1 authored and copybara-github committed Jul 31, 2024
1 parent e3fc77f commit d92e7c9
Show file tree
Hide file tree
Showing 6 changed files with 391 additions and 36 deletions.
126 changes: 115 additions & 11 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,25 @@
# limitations under the License.
#

from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
RagCorpus,
RagFile,
RagResource,
)

from google.cloud import aiplatform

from vertexai.preview import rag
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
ImportRagFilesConfig,
ImportRagFilesRequest,
ImportRagFilesResponse,
JiraSource as GapicJiraSource,
RagCorpus as GapicRagCorpus,
RagFile as GapicRagFile,
SlackSource as GapicSlackSource,
RagContexts,
RetrieveContextsResponse,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.protobuf import timestamp_pb2


TEST_PROJECT = "test-project"
Expand All @@ -55,10 +56,10 @@
TEST_PROJECT, TEST_REGION
)
)
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
TEST_RAG_CORPUS = RagCorpus(
TEST_RAG_CORPUS = rag.RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
Expand Down Expand Up @@ -144,11 +145,114 @@
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
)
TEST_RAG_FILE = RagFile(
TEST_RAG_FILE = rag.RagFile(
name=TEST_RAG_FILE_RESOURCE_NAME,
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
)
# Slack sources
#
# Fixtures for the Slack import path: an SDK-level ``rag.SlackChannelsSource``
# plus the GAPIC import config/request built from the same channel ids,
# time range and secret versions.
TEST_SLACK_CHANNEL_ID = "123"
TEST_SLACK_CHANNEL_ID_2 = "456"

# Both timestamps are stamped with the current time when this module is
# imported; the start time is stamped first.
TEST_SLACK_START_TIME = timestamp_pb2.Timestamp()
TEST_SLACK_START_TIME.GetCurrentTime()
TEST_SLACK_END_TIME = timestamp_pb2.Timestamp()
TEST_SLACK_END_TIME.GetCurrentTime()

TEST_SLACK_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)
TEST_SLACK_API_KEY_SECRET_VERSION_2 = (
    "projects/test-project/secrets/test-secret/versions/2"
)

# SDK-side source: the first channel carries an explicit time range, the
# second one only an id and an API key.
TEST_SLACK_SOURCE = rag.SlackChannelsSource(
    channels=[
        rag.SlackChannel(
            api_key=TEST_SLACK_API_KEY_SECRET_VERSION,
            channel_id=TEST_SLACK_CHANNEL_ID,
            start_time=TEST_SLACK_START_TIME,
            end_time=TEST_SLACK_END_TIME,
        ),
        rag.SlackChannel(
            api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2,
            channel_id=TEST_SLACK_CHANNEL_ID_2,
        ),
    ],
)

# GAPIC-side config: chunking settings first, then one ``SlackChannels`` entry
# per SDK channel, each wrapping a single channel and its API key config.
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_overlap=TEST_CHUNK_OVERLAP,
        chunk_size=TEST_CHUNK_SIZE,
    )
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
    GapicSlackSource.SlackChannels(
        channels=[
            GapicSlackSource.SlackChannels.SlackChannel(
                channel_id=slack_channel_id,
                start_time=range_start,
                end_time=range_end,
            ),
        ],
        api_key_config=api_auth.ApiAuth.ApiKeyConfig(
            api_key_secret_version=secret_version
        ),
    )
    # (channel id, secret version, start, end); the second channel has no
    # time range, so both bounds are passed as None.
    for slack_channel_id, secret_version, range_start, range_end in (
        (
            TEST_SLACK_CHANNEL_ID,
            TEST_SLACK_API_KEY_SECRET_VERSION,
            TEST_SLACK_START_TIME,
            TEST_SLACK_END_TIME,
        ),
        (
            TEST_SLACK_CHANNEL_ID_2,
            TEST_SLACK_API_KEY_SECRET_VERSION_2,
            None,
            None,
        ),
    )
]
TEST_IMPORT_REQUEST_SLACK_SOURCE = ImportRagFilesRequest(
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE,
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
)
# Jira sources
#
# Fixtures for the Jira import path: an SDK-level ``rag.JiraSource`` plus the
# GAPIC import config/request built from the same email, project, custom
# query, server URI and secret version.
TEST_JIRA_EMAIL = "test@test.com"
TEST_JIRA_PROJECT = "test-project"
TEST_JIRA_CUSTOM_QUERY = "test-custom-query"
TEST_JIRA_SERVER_URI = "test.atlassian.net"
TEST_JIRA_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)

# SDK-side source holding a single query.
TEST_JIRA_SOURCE = rag.JiraSource(
    queries=[
        rag.JiraQuery(
            server_uri=TEST_JIRA_SERVER_URI,
            email=TEST_JIRA_EMAIL,
            api_key=TEST_JIRA_API_KEY_SECRET_VERSION,
            jira_projects=[TEST_JIRA_PROJECT],
            custom_queries=[TEST_JIRA_CUSTOM_QUERY],
        )
    ],
)

# GAPIC-side config: chunking settings first, then the single ``JiraQueries``
# entry assigned onto the Jira source afterwards.
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_overlap=TEST_CHUNK_OVERLAP,
        chunk_size=TEST_CHUNK_SIZE,
    )
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
    GapicJiraSource.JiraQueries(
        server_uri=TEST_JIRA_SERVER_URI,
        email=TEST_JIRA_EMAIL,
        api_key_config=api_auth.ApiAuth.ApiKeyConfig(
            api_key_secret_version=TEST_JIRA_API_KEY_SECRET_VERSION
        ),
        projects=[TEST_JIRA_PROJECT],
        custom_queries=[TEST_JIRA_CUSTOM_QUERY],
    )
]
TEST_IMPORT_REQUEST_JIRA_SOURCE = ImportRagFilesRequest(
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE,
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
)

# Retrieval
TEST_QUERY_TEXT = "What happen to the fox and the dog?"
Expand All @@ -162,11 +266,11 @@
]
)
TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)
TEST_RAG_RESOURCE = RagResource(
TEST_RAG_RESOURCE = rag.RagResource(
rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME,
rag_file_ids=[TEST_RAG_FILE_ID],
)
TEST_RAG_RESOURCE_INVALID_NAME = RagResource(
TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource(
rag_corpus="213lkj-1/23jkl/",
rag_file_ids=[TEST_RAG_FILE_ID],
)
26 changes: 26 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def import_files_request_eq(returned_request, expected_request):
returned_request.import_rag_files_config.google_drive_source.resource_ids
== expected_request.import_rag_files_config.google_drive_source.resource_ids
)
assert (
returned_request.import_rag_files_config.slack_source.channels
== expected_request.import_rag_files_config.slack_source.channels
)
assert (
returned_request.import_rag_files_config.jira_source.jira_queries
== expected_request.import_rag_files_config.jira_source.jira_queries
)


@pytest.mark.usefixtures("google_auth_mock")
Expand Down Expand Up @@ -421,6 +429,24 @@ def test_prepare_import_files_request_invalid_path(self):
)
e.match("path must be a Google Cloud Storage uri or a Google Drive url")

def test_prepare_import_files_request_slack_source(self):
    """Check the import request built from the Slack source fixture.

    Builds a request for the test corpus with the shared chunking settings
    and compares it against ``TEST_IMPORT_REQUEST_SLACK_SOURCE``.
    """
    built_request = prepare_import_files_request(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        source=tc.TEST_SLACK_SOURCE,
        chunk_overlap=tc.TEST_CHUNK_OVERLAP,
        chunk_size=tc.TEST_CHUNK_SIZE,
    )
    expected_request = tc.TEST_IMPORT_REQUEST_SLACK_SOURCE
    import_files_request_eq(built_request, expected_request)

def test_prepare_import_files_request_jira_source(self):
    """Check the import request built from the Jira source fixture.

    Builds a request for the test corpus with the shared chunking settings
    and compares it against ``TEST_IMPORT_REQUEST_JIRA_SOURCE``.
    """
    built_request = prepare_import_files_request(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        source=tc.TEST_JIRA_SOURCE,
        chunk_overlap=tc.TEST_CHUNK_OVERLAP,
        chunk_size=tc.TEST_CHUNK_SIZE,
    )
    expected_request = tc.TEST_IMPORT_REQUEST_JIRA_SOURCE
    import_files_request_eq(built_request, expected_request)

def test_set_embedding_model_config_set_both_error(self):
embedding_model_config = rag.EmbeddingModelConfig(
publisher_model="whatever",
Expand Down
12 changes: 12 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
JiraSource,
JiraQuery,
RagCorpus,
RagFile,
RagResource,
SlackChannel,
SlackChannelsSource,
)


Expand All @@ -58,4 +64,10 @@
"Retrieval",
"VertexRagStore",
"RagResource",
"RagFile",
"RagCorpus",
"JiraSource",
"JiraQuery",
"SlackChannel",
"SlackChannelsSource",
)
Loading

0 comments on commit d92e7c9

Please sign in to comment.