From 43e4fa6a0c99ddb7fb162c6c27ede396ee72a9c9 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Mon, 4 Nov 2024 09:43:16 -0700 Subject: [PATCH 01/15] RAG Pipeline initial commit --- lambda/__init__.py | 0 lambda/repository/lambda_functions.py | 126 ++++++- .../repository/pipeline_ingest_documents.py | 169 ++++++++++ lambda/repository/state_machine/__init__.py | 0 .../state_machine/list_modified_objects.py | 169 ++++++++++ .../pipeline_ingest_documents.py | 71 ++++ lambda/utilities/common_functions.py | 56 +++- lambda/utilities/create_env_variables.py | 35 ++ lambda/utilities/file_processing.py | 16 +- lambda/utilities/validation.py | 158 +++++++++ lib/core/iam/rag.json | 92 +++-- lib/rag/index.ts | 94 ++++-- lib/rag/state_machine/constants.ts | 23 ++ lib/rag/state_machine/ingest-pipeline.ts | 315 ++++++++++++++++++ lib/schema.ts | 39 ++- lib/serve/index.ts | 5 +- lib/serve/rest-api/Dockerfile | 2 +- package-lock.json | 64 ++-- package.json | 4 +- 19 files changed, 1321 insertions(+), 117 deletions(-) create mode 100644 lambda/__init__.py create mode 100644 lambda/repository/pipeline_ingest_documents.py create mode 100644 lambda/repository/state_machine/__init__.py create mode 100644 lambda/repository/state_machine/list_modified_objects.py create mode 100644 lambda/repository/state_machine/pipeline_ingest_documents.py create mode 100644 lambda/utilities/create_env_variables.py create mode 100644 lambda/utilities/validation.py create mode 100644 lib/rag/state_machine/constants.ts create mode 100644 lib/rag/state_machine/ingest-pipeline.ts diff --git a/lambda/__init__.py b/lambda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index 90d3f9bd..3bf624e1 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -19,17 +19,18 @@ from typing import Any, Dict, List import boto3 -import create_env_variables # noqa: F401 +import requests from botocore.config import Config from lisapy.langchain import LisaOpenAIEmbeddings -from lisapy.utils import get_cert_path -from utilities.common_functions import api_wrapper, get_id_token, retry_config +from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config from utilities.file_processing import process_record +from utilities.validation import validate_model_name, ValidationError from utilities.vector_store import get_vector_store_client logger = logging.getLogger(__name__) session = boto3.Session() ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) +secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config) s3 = session.client( "s3", @@ -54,13 +55,130 @@ def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"] base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve" + cert_path = get_cert_path(iam_client) embedding = LisaOpenAIEmbeddings( - lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=get_cert_path(iam_client) + lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path ) return embedding +def _get_embeddings_pipeline(model_name: str) -> Any: + """ + Get embeddings for pipeline requests using management token. 
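+
+    Unlike _get_embeddings above, which authenticates with the caller's ID
+    token, this path authenticates with the deployment's management key so
+    that unattended pipeline invocations can reach the embeddings endpoint.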
+ + Args: + model_name: Name of the embedding model to use + + Raises: + ValidationError: If model name is invalid + Exception: If API request fails + """ + logger.info("Starting pipeline embeddings request") + validate_model_name(model_name) + + # Create embeddings client that matches LisaOpenAIEmbeddings interface + class PipelineEmbeddings: + def __init__(self) -> None: + try: + # Get the management key secret name from SSM Parameter Store + secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) + secret_name = secret_name_param["Parameter"]["Value"] + + # Get the management token from Secrets Manager using the secret name + secret_response = secrets_client.get_secret_value(SecretId=secret_name) + self.token = secret_response["SecretString"] + + # Get the API endpoint from SSM + lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) + self.base_url = ( + f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" + ) + + # Get certificate path for SSL verification + self.cert_path = get_cert_path(iam_client) + + logger.info("Successfully initialized pipeline embeddings") + except Exception: + logger.error("Failed to initialize pipeline embeddings", exc_info=True) + raise + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + if not texts: + raise ValidationError("No texts provided for embedding") + + logger.info(f"Embedding {len(texts)} documents") + try: + url = f"{self.base_url}/embeddings" + request_data = {"input": texts, "model": model_name} + + response = requests.post( + url, + json=request_data, + headers={"Authorization": self.token, "Content-Type": "application/json"}, + verify=self.cert_path, # Use proper SSL verification + timeout=300, # 5 minute timeout + ) + + if response.status_code != 200: + logger.error(f"Embedding request failed with status {response.status_code}") + logger.error(f"Response content: {response.text}") + raise Exception(f"Embedding request failed with status {response.status_code}") + + result = response.json() + logger.debug(f"API Response: {result}") # Log the full response for debugging + + # Handle different response formats + embeddings = [] + if isinstance(result, dict): + if "data" in result: + # OpenAI-style format + for item in result["data"]: + if isinstance(item, dict) and "embedding" in item: + embeddings.append(item["embedding"]) + else: + embeddings.append(item) # Assume the item itself is the embedding + else: + # Try to find embeddings in the response + for key in ["embeddings", "embedding", "vectors", "vector"]: + if key in result: + embeddings = result[key] + break + elif isinstance(result, list): + # Direct list format + embeddings = result + + if not embeddings: + logger.error(f"Could not find embeddings in response: {result}") + raise Exception("No embeddings found in API response") + + if len(embeddings) != len(texts): + logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})") + raise Exception("Number of embeddings does not match number of input texts") + + logger.info(f"Successfully embedded {len(texts)} documents") + return embeddings + + except requests.Timeout: + logger.error("Embedding request timed out") + raise Exception("Embedding request timed out after 5 minutes") + except requests.RequestException as e: + logger.error(f"Request failed: {str(e)}", exc_info=True) + raise + except Exception as e: + logger.error(f"Failed to get embeddings: {str(e)}", 
exc_info=True) + raise + + def embed_query(self, text: str) -> List[float]: + if not text or not isinstance(text, str): + raise ValidationError("Invalid query text") + + logger.info("Embedding single query text") + return self.embed_documents([text])[0] + + return PipelineEmbeddings() + + @api_wrapper def list_all(event: dict, context: dict) -> List[Dict[str, Any]]: """Return info on all available repositories. diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py new file mode 100644 index 00000000..5f2c2eb6 --- /dev/null +++ b/lambda/repository/pipeline_ingest_documents.py @@ -0,0 +1,169 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda function for pipeline document ingestion.""" + +import logging +import os +from typing import Any, Dict, List + +import boto3 +from utilities.common_functions import retry_config +from utilities.file_processing import process_record +from utilities.validation import ( + safe_error_response, + validate_chunk_params, + validate_model_name, + validate_repository_type, + ValidationError, +) +from utilities.vector_store import get_vector_store_client + +from .lambda_functions import _get_embeddings_pipeline + +logger = logging.getLogger(__name__) +session = boto3.Session() +ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config) + + +def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]: + """ + Split texts and metadata into batches of specified size. + + Args: + texts: List of text strings to batch + metadatas: List of metadata dictionaries + batch_size: Maximum size of each batch + Returns: + List of tuples containing (texts_batch, metadatas_batch) + """ + batches = [] + for i in range(0, len(texts), batch_size): + text_batch = texts[i: i + batch_size] + metadata_batch = metadatas[i: i + batch_size] + batches.append((text_batch, metadata_batch)) + return batches + + +def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Handle pipeline document ingestion. + + Process a single document from the Map state by chunking and storing in vectorstore. + Reuses existing document processing and vectorstore infrastructure. + Configuration is provided through environment variables set by the state machine. 
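+
+    Example event (illustrative values; only bucket and key are required):
+        {"bucket": "my-docs-bucket", "key": "docs/report.pdf"}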
+ + Args: + event: Event containing the document bucket and key + context: Lambda context + + Returns: + Dictionary with status code and response body + """ + try: + # Get document location from event + if "bucket" not in event or "key" not in event: + raise ValidationError("Missing required fields: bucket and key") + + bucket = event["bucket"] + key = event["key"] + s3_key = f"s3://{bucket}/{key}" + + # Get all configuration from environment variables + required_env_vars = ["CHUNK_SIZE", "CHUNK_OVERLAP", "EMBEDDING_MODEL", "REPOSITORY_TYPE", "REPOSITORY_ID"] + missing_vars = [var for var in required_env_vars if var not in os.environ] + if missing_vars: + raise ValidationError(f"Missing required environment variables: {', '.join(missing_vars)}") + + chunk_size = int(os.environ["CHUNK_SIZE"]) + chunk_overlap = int(os.environ["CHUNK_OVERLAP"]) + embedding_model = os.environ["EMBEDDING_MODEL"] + repository_type = os.environ["REPOSITORY_TYPE"] + repository_id = os.environ["REPOSITORY_ID"] + + # Validate inputs + validate_model_name(embedding_model) + validate_repository_type(repository_type) + validate_chunk_params(chunk_size, chunk_overlap) + + logger.info(f"Processing document {s3_key} for repository {repository_id} of type {repository_type}") + + # Process document using existing utilities, passing the bucket explicitly + docs = process_record( + s3_keys=[key], # Changed from s3_key to just key since process_record expects just the key + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + bucket=bucket, # Pass the bucket explicitly + ) + + if not docs or not docs[0]: + raise ValidationError(f"No content extracted from document {s3_key}") + + # Prepare texts and metadata + texts = [] + metadatas = [] + for doc_list in docs: + for doc in doc_list: + texts.append(doc.page_content) + # Add repository ID to metadata + doc.metadata["repository_id"] = repository_id + metadatas.append(doc.metadata) + + # Get embeddings using pipeline-specific function that uses IAM auth + embeddings = _get_embeddings_pipeline(model_name=embedding_model) + + # Initialize vector store using model name as index, matching lambda_functions.py pattern + vs = get_vector_store_client( + store=repository_type, # Changed from repository_type to store + index=embedding_model, # Use model name as index to match lambda_functions.py + embeddings=embeddings, + ) + + # Process documents in batches + all_ids = [] + batches = batch_texts(texts, metadatas) + total_batches = len(batches) + + logger.info(f"Processing {len(texts)} texts in {total_batches} batches") + + for i, (text_batch, metadata_batch) in enumerate(batches, 1): + logger.info(f"Processing batch {i}/{total_batches} with {len(text_batch)} texts") + batch_ids = vs.add_texts(texts=text_batch, metadatas=metadata_batch) + if not batch_ids: + raise Exception(f"Failed to store documents in vector store for batch {i}") + all_ids.extend(batch_ids) + logger.info(f"Successfully processed batch {i}") + + if not all_ids: + raise Exception("Failed to store any documents in vector store") + + logger.info(f"Successfully processed {len(all_ids)} chunks from {s3_key} for repository {repository_id}") + + return { + "statusCode": 200, + "body": { + "message": f"Successfully processed document {s3_key}", + "repository_id": repository_id, + "repository_type": repository_type, + "chunks_processed": len(all_ids), + "document_ids": all_ids, + }, + } + + except ValidationError as e: + # Return 400 for validation errors + return safe_error_response(e) + except Exception as e: + # Return 500 
for other errors + return safe_error_response(e) diff --git a/lambda/repository/state_machine/__init__.py b/lambda/repository/state_machine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lambda/repository/state_machine/list_modified_objects.py b/lambda/repository/state_machine/list_modified_objects.py new file mode 100644 index 00000000..81228ca7 --- /dev/null +++ b/lambda/repository/state_machine/list_modified_objects.py @@ -0,0 +1,169 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lambda handlers for ListModifiedObjects state machine.""" + +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict + +import boto3 +from utilities.validation import safe_error_response, ValidationError + +logger = logging.getLogger(__name__) + + +def normalize_prefix(prefix: str) -> str: + """ + Normalize the S3 prefix by handling trailing slashes. + + Args: + prefix: S3 prefix to normalize + + Returns: + Normalized prefix + """ + if not prefix: + return "" + + # Remove leading/trailing slashes and spaces + prefix = prefix.strip().strip("/") + + # If prefix is not empty, ensure it ends with a slash + if prefix: + prefix = f"{prefix}/" + + return prefix + + +def validate_bucket_prefix(bucket: str, prefix: str) -> bool: + """ + Validate bucket and prefix parameters. + + Args: + bucket: S3 bucket name + prefix: S3 prefix + + Returns: + bool: True if valid + + Raises: + ValidationError: If parameters are invalid + """ + if not bucket or not isinstance(bucket, str): + raise ValidationError(f"Invalid bucket name: {bucket}") + + if prefix is None or not isinstance(prefix, str): + raise ValidationError(f"Invalid prefix: {prefix}") + + # Basic path traversal check + if ".." in prefix: + raise ValidationError(f"Invalid prefix: path traversal detected in {prefix}") + + return True + + +def handle_list_modified_objects(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. 
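+
+    Example event detail (illustrative; bucket may also arrive as a plain
+    string, and the prefix may instead come from detail.object.key):
+        {"detail": {"bucket": {"name": "my-docs-bucket"}, "prefix": "docs/"}}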
+ + Args: + event: Event data containing bucket and prefix information + context: Lambda context + + Returns: + Dictionary containing array of files with their bucket and key + """ + try: + # Log the full event for debugging + logger.debug(f"Received event: {event}") + + # Extract bucket and prefix from event, handling both event types + detail = event.get("detail", {}) + + # Handle both direct bucket name and nested bucket structure + bucket = detail.get("bucket") + if isinstance(bucket, dict): + bucket = bucket.get("name") + + # For event triggers, use the object key as prefix if no prefix specified + prefix = detail.get("prefix") + if not prefix and "object" in detail: + prefix = detail["object"].get("key", "") + + # Normalize the prefix + prefix = normalize_prefix(prefix) + + # Add debug logging + logger.info(f"Processing request for bucket: {bucket}, normalized prefix: {prefix}") + + # Validate inputs + validate_bucket_prefix(bucket, prefix) + + # Initialize S3 client + s3_client = boto3.client("s3") + + # Calculate timestamp for 24 hours ago + twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24) + + # List to store matching objects + modified_files = [] + + # Use paginator to handle case where there are more than 1000 objects + paginator = s3_client.get_paginator("list_objects_v2") + + # Add debug logging for S3 list operation + logger.info(f"Listing objects in {bucket}/{prefix} modified after {twenty_four_hours_ago}") + + # Iterate through all objects in the bucket/prefix + try: + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + if "Contents" not in page: + logger.info(f"No contents found in page for {bucket}/{prefix}") + continue + + # Check each object's last modified time + for obj in page["Contents"]: + last_modified = obj["LastModified"] + if last_modified >= twenty_four_hours_ago: + logger.info(f"Found modified file: {obj['Key']} (Last Modified: {last_modified})") + modified_files.append({"bucket": bucket, "key": obj["Key"]}) + else: + logger.debug( + f"Skipping file {obj['Key']} - Last modified {last_modified} before cutoff " + f"{twenty_four_hours_ago}" + ) + except Exception as e: + logger.error(f"Error during S3 list operation: {str(e)}", exc_info=True) + raise + + result = { + "files": modified_files, + "metadata": { + "bucket": bucket, + "prefix": prefix, + "cutoff_time": twenty_four_hours_ago.isoformat(), + "files_found": len(modified_files), + }, + } + + logger.info(f"Found {len(modified_files)} modified files in {bucket}/{prefix}") + return result + + except ValidationError as e: + logger.error(f"Validation error: {str(e)}") + return safe_error_response(e) + except Exception as e: + logger.error(f"Error listing objects: {str(e)}", exc_info=True) + return safe_error_response(e) diff --git a/lambda/repository/state_machine/pipeline_ingest_documents.py b/lambda/repository/state_machine/pipeline_ingest_documents.py new file mode 100644 index 00000000..ec681144 --- /dev/null +++ b/lambda/repository/state_machine/pipeline_ingest_documents.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lambda handlers for PipelineIngestDocuments state machine.""" + +import os +from typing import Any, Dict + +import boto3 +from models.document_processor import DocumentProcessor +from models.vectorstore import VectorStore + + +def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Process a single document from the Map state by chunking and storing in vectorstore. + + Args: + event: Event containing the document bucket and key + context: Lambda context + + Returns: + Dictionary indicating success/failure + """ + try: + # Get document location from event + bucket = event["bucket"] + key = event["key"] + + # Get configuration from environment variables + chunk_size = int(os.environ["CHUNK_SIZE"]) + chunk_overlap = int(os.environ["CHUNK_OVERLAP"]) + embedding_model = os.environ["EMBEDDING_MODEL"] + collection_name = os.environ["COLLECTION_NAME"] + + # Initialize document processor and vectorstore + doc_processor = DocumentProcessor() + vectorstore = VectorStore(collection_name=collection_name, embedding_model=embedding_model) + + # Download and process document + s3_client = boto3.client("s3") + response = s3_client.get_object(Bucket=bucket, Key=key) + content = response["Body"].read().decode("utf-8") + + # Chunk document + chunks = doc_processor.chunk_text(text=content, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + # Store chunks in vectorstore + vectorstore.add_texts(texts=chunks, metadata={"source": f"s3://{bucket}/{key}"}) + + return { + "statusCode": 200, + "body": { + "message": f"Successfully processed document s3://{bucket}/{key}", + "chunks_processed": len(chunks), + }, + } + + except Exception as e: + print(f"Error processing document: {str(e)}") + raise diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index 0acffc0c..66308996 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -24,9 +24,10 @@ from typing import Any, Callable, Dict, TypeVar, Union import boto3 -import create_env_variables # noqa type: ignore from botocore.config import Config +from . import create_env_variables # noqa type: ignore + retry_config = Config( retries={ "max_attempts": 3, @@ -289,29 +290,52 @@ def get_id_token(event: dict) -> str: return str(auth_header).removeprefix("Bearer ").removeprefix("bearer ").strip() +_cert_file = None + + @cache def get_cert_path(iam_client: Any) -> Union[str, bool]: """ Get cert path for IAM certs for SSL validation against LISA Serve endpoint. - If no SSL Cert ARN is specified just default verify to true and the cert will need to be - signed by a known CA. Assume cert is signed with known CA if coming from ACM. - - Note: this function is a copy of the same function in the lisa-sdk path. To avoid inflating the deployment size of - the Lambda functions, this function was copied here instead of including the entire lisa-sdk path. + Returns the path to the certificate file for SSL verification, or True to use + default verification if no certificate ARN is specified. 
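+
+    The result is cached: the certificate body is fetched from IAM once,
+    written to a temporary file, and that file's path is reused; on any
+    failure the function falls back to default CA verification.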
""" - cert_arn = os.environ.get("RESTAPI_SSL_CERT_ARN", "") - if not cert_arn or cert_arn.split(":")[2] == "acm": - return True + global _cert_file - # We have the arn, but we need the name which is the last part of the arn - rest_api_cert = iam_client.get_server_certificate(ServerCertificateName=cert_arn.split("/")[1]) - cert_body = rest_api_cert["ServerCertificate"]["CertificateBody"] - cert_file = tempfile.NamedTemporaryFile(delete=False) - cert_file.write(cert_body.encode("utf-8")) - rest_api_cert_path = cert_file.name + cert_arn = os.environ.get("RESTAPI_SSL_CERT_ARN") + if not cert_arn: + logger.info("No SSL certificate ARN specified, using default verification") + return True - return rest_api_cert_path + try: + # Clean up previous cert file if it exists + if _cert_file and os.path.exists(_cert_file.name): + try: + os.unlink(_cert_file.name) + except Exception as e: + logger.warning(f"Failed to clean up previous cert file: {e}") + + # Get the certificate name from the ARN + cert_name = cert_arn.split("/")[1] + logger.info(f"Retrieving certificate '{cert_name}' from IAM") + + # Get the certificate from IAM + rest_api_cert = iam_client.get_server_certificate(ServerCertificateName=cert_name) + cert_body = rest_api_cert["ServerCertificate"]["CertificateBody"] + + # Create a new temporary file + _cert_file = tempfile.NamedTemporaryFile(delete=False) + _cert_file.write(cert_body.encode("utf-8")) + _cert_file.flush() + + logger.info(f"Certificate saved to temporary file: {_cert_file.name}") + return _cert_file.name + + except Exception as e: + logger.error(f"Failed to get certificate from IAM: {e}", exc_info=True) + # If we fail to get the cert, return True to fall back to default verification + return True @cache diff --git a/lambda/utilities/create_env_variables.py b/lambda/utilities/create_env_variables.py new file mode 100644 index 00000000..549d0a3d --- /dev/null +++ b/lambda/utilities/create_env_variables.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Module to create and set environment variables for Lambda functions.""" +import os + +import boto3 + + +def setup_environment() -> None: + """Set up environment variables needed by Lambda functions.""" + # Set up SSL certificate path if not already set + if "SSL_CERT_FILE" not in os.environ: + # Default to the Amazon root CA bundle location in Lambda + os.environ["SSL_CERT_FILE"] = "/etc/pki/tls/certs/ca-bundle.crt" + + # Set up any other common environment variables here + if "AWS_REGION" not in os.environ: + session = boto3.Session() + os.environ["AWS_REGION"] = session.region_name or "us-east-1" + + +# Run setup when module is imported +setup_environment() diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py index 7f9102a1..f82640c2 100644 --- a/lambda/utilities/file_processing.py +++ b/lambda/utilities/file_processing.py @@ -124,15 +124,25 @@ def _extract_docx_content(s3_object: dict) -> str: return output -def process_record(s3_keys: List[str], chunk_size: Optional[int], chunk_overlap: Optional[int]) -> List[List[Document]]: +def process_record( + s3_keys: List[str], chunk_size: Optional[int], chunk_overlap: Optional[int], bucket: Optional[str] = None +) -> List[List[Document]]: """Process a single file from S3. Parameters ---------- - record (dict): The S3 record to process. + s3_keys (List[str]): List of S3 keys to process + chunk_size (Optional[int]): Size of chunks to split text into + chunk_overlap (Optional[int]): Number of characters to overlap between chunks + bucket (Optional[str]): S3 bucket name. If not provided, uses os.environ["BUCKET_NAME"] + Returns + ------- + List[List[Document]]: List of document chunks for each processed file """ - bucket = os.environ["BUCKET_NAME"] + if bucket is None: + bucket = os.environ["BUCKET_NAME"] + chunks = [] for key in s3_keys: content_type = key.split(".")[-1] diff --git a/lambda/utilities/validation.py b/lambda/utilities/validation.py new file mode 100644 index 00000000..11c7a37f --- /dev/null +++ b/lambda/utilities/validation.py @@ -0,0 +1,158 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validation utilities for Lambda functions.""" +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class ValidationError(Exception): + """Custom exception for validation errors.""" + + pass + + +class SecurityError(Exception): + """Custom exception for security-related errors.""" + + pass + + +def validate_model_name(model_name: str) -> bool: + """Validate model name is a non-empty string. 
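+
+    Example (model name is illustrative):
+        validate_model_name("my-embedding-model")  # returns True
+        validate_model_name("")  # raises ValidationError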
+ + Args: + model_name: Name of the model to validate + + Returns: + bool: True if valid + + Raises: + ValidationError: If model name is invalid + """ + if not isinstance(model_name, str): + raise ValidationError("Model name must be a string") + + if not model_name or model_name.isspace(): + raise ValidationError("Model name cannot be empty") + + return True + + +def validate_repository_type(repo_type: str) -> bool: + """Validate repository type against allowed types. + + Args: + repo_type: Repository type to validate + + Returns: + bool: True if valid + + Raises: + ValidationError: If repository type is invalid + """ + ALLOWED_TYPES = ["opensearch", "pgvector"] + + if not isinstance(repo_type, str): + raise ValidationError("Repository type must be a string") + + if repo_type not in ALLOWED_TYPES: + raise ValidationError(f"Invalid repository type. Must be one of: {ALLOWED_TYPES}") + + return True + + +def validate_s3_key(key: str) -> bool: + """Validate S3 key format and allowed extensions. + + Args: + key: S3 key to validate + + Returns: + bool: True if valid + + Raises: + ValidationError: If key is invalid + """ + ALLOWED_EXTENSIONS = [".txt", ".pdf", ".docx"] + + if not isinstance(key, str): + raise ValidationError("S3 key must be a string") + + if not key or key.isspace(): + raise ValidationError("S3 key cannot be empty") + + if not any(key.lower().endswith(ext) for ext in ALLOWED_EXTENSIONS): + raise ValidationError(f"Invalid file type. Must be one of: {ALLOWED_EXTENSIONS}") + + # Basic path traversal check + if ".." in key: + raise SecurityError("Path traversal detected in S3 key") + + return True + + +def validate_chunk_params(chunk_size: Optional[int], chunk_overlap: Optional[int]) -> bool: + """ + Validate chunking parameters. + + Args: + chunk_size: Size of chunks + chunk_overlap: Overlap between chunks + + Returns: + bool: True if valid + + Raises: + ValidationError: If parameters are invalid + """ + if chunk_size is not None: + if not isinstance(chunk_size, int): + raise ValidationError("Chunk size must be an integer") + + if chunk_size < 100 or chunk_size > 10000: + raise ValidationError("Chunk size must be between 100 and 10000") + + if chunk_overlap is not None: + if not isinstance(chunk_overlap, int): + raise ValidationError("Chunk overlap must be an integer") + + if chunk_overlap < 0: + raise ValidationError("Chunk overlap cannot be negative") + + if chunk_size and chunk_overlap >= chunk_size: + raise ValidationError("Chunk overlap must be less than chunk size") + + return True + + +def safe_error_response(error: Exception) -> dict: + """Create a safe error response that doesn't leak implementation details. 
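+
+    Maps ValidationError to 400 with its message, SecurityError to 403 with
+    a generic message, and any other exception to 500; internal details are
+    logged but never returned to the caller.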
+ + Args: + error: The exception that occurred + + Returns: + dict: Sanitized error response + """ + if isinstance(error, ValidationError): + return {"statusCode": 400, "body": {"message": str(error)}} + elif isinstance(error, SecurityError): + return {"statusCode": 403, "body": {"message": "Security validation failed"}} + else: + # Log the full error internally but return generic message + logger.error(f"Internal error: {str(error)}", exc_info=True) + return {"statusCode": 500, "body": {"message": "Internal server error"}} diff --git a/lib/core/iam/rag.json b/lib/core/iam/rag.json index 84cbf0bd..db5451a2 100644 --- a/lib/core/iam/rag.json +++ b/lib/core/iam/rag.json @@ -1,25 +1,71 @@ { - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - "logs:CreateLogGroup", - "logs:CreateLogStream", - "logs:PutLogEvents", - "ec2:CreateNetworkInterface", - "ec2:DescribeNetworkInterfaces", - "ec2:DescribeSubnets", - "ec2:DeleteNetworkInterface", - "ec2:AssignPrivateIpAddresses", - "ec2:UnassignPrivateIpAddresses" - ], - "Resource": "*" - }, - { - "Action": ["iam:GetServerCertificate"], - "Resource": "arn:${AWS::Partition}:iam::${AWS::AccountId}:server-certificate/*", - "Effect": "Allow" - } - ] + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:PutLogEvents" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "ec2:CreateNetworkInterface", + "ec2:DescribeNetworkInterfaces", + "ec2:DescribeSubnets", + "ec2:DeleteNetworkInterface", + "ec2:AssignPrivateIpAddresses", + "ec2:UnassignPrivateIpAddresses" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:${AWS::Partition}:s3:::${S3Bucket}", + "arn:${AWS::Partition}:s3:::${S3Bucket}/*" + ] + }, + { + "Effect": "Allow", + "Action": [ + "secretsmanager:GetSecretValue" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "execute-api:Invoke" + ], + "Resource": [ + "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:*/*/POST/serve/embeddings" + ] + }, + { + "Effect": "Allow", + "Action": [ + "ssm:GetParameter" + ], + "Resource": [ + "arn:${AWS::Partition}:ssm:${AWS::Region}:${AWS::AccountId}:parameter/dev/dev/lisa/lisaServeRestApiUri", + "arn:${AWS::Partition}:ssm:${AWS::Region}:${AWS::AccountId}:parameter/dev/dev/lisa/LisaServeRagPGVectorConnectionInfo", + "arn:${AWS::Partition}:ssm:${AWS::Region}:${AWS::AccountId}:parameter/dev/dev/lisa/opensearchEndpoint" + ] + }, + { + "Effect": "Allow", + "Action": [ + "iam:GetServerCertificate" + ], + "Resource": "arn:${AWS::Partition}:iam::${AWS::AccountId}:server-certificate/*" + } + ] } diff --git a/lib/rag/index.ts b/lib/rag/index.ts index 2db54b2d..73c5083b 100644 --- a/lib/rag/index.ts +++ b/lib/rag/index.ts @@ -24,7 +24,7 @@ import { CfnOutput, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib'; import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway'; import { ISecurityGroup, Peer, Port, SecurityGroup } from 'aws-cdk-lib/aws-ec2'; import { AnyPrincipal, CfnServiceLinkedRole, Effect, PolicyStatement, Role } from 'aws-cdk-lib/aws-iam'; -import { Code, LayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda'; +import { Code, LayerVersion, Runtime, ILayerVersion } from 'aws-cdk-lib/aws-lambda'; import { Domain, EngineVersion, IDomain } from 'aws-cdk-lib/aws-opensearchservice'; import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds'; import 
{ Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3'; @@ -40,6 +40,8 @@ import { createCdkId } from '../core/utils'; import { Vpc } from '../networking/vpc'; import { BaseProps, RagRepositoryType } from '../schema'; +import { IngestPipelineStateMachine } from './state_machine/ingest-pipeline'; + const HERE = path.resolve(__dirname); const RAG_LAYER_PATH = path.join(HERE, 'layer'); const SDK_PATH: string = path.resolve(HERE, '..', '..', 'lisa-sdk'); @@ -117,6 +119,34 @@ export class LisaRagStack extends Stack { bucket.grantRead(lambdaRole); bucket.grantPut(lambdaRole); + // Build RAG Lambda layer + const ragLambdaLayer = new Layer(this, 'RagLayer', { + config: config, + path: RAG_LAYER_PATH, + description: 'Lambad dependencies for RAG API', + architecture: ARCHITECTURE, + autoUpgrade: true, + assetPath: config.lambdaLayerAssets?.ragLayerPath, + }); + + // Build SDK Layer + let sdkLayer: ILayerVersion; + if (config.lambdaLayerAssets?.sdkLayerPath) { + sdkLayer = new LayerVersion(this, 'SdkLayer', { + code: Code.fromAsset(config.lambdaLayerAssets?.sdkLayerPath), + compatibleRuntimes: [Runtime.PYTHON_3_10], + removalPolicy: config.removalPolicy, + description: 'LISA SDK common layer', + }); + } else { + sdkLayer = new PythonLayerVersion(this, 'SdkLayer', { + entry: SDK_PATH, + compatibleRuntimes: [Runtime.PYTHON_3_10], + removalPolicy: config.removalPolicy, + description: 'LISA SDK common layer', + }); + } + const registeredRepositories = []; for (const ragConfig of config.ragRepositories) { @@ -172,7 +202,7 @@ export class LisaRagStack extends Stack { version: EngineVersion.OPENSEARCH_2_9, enableVersionUpgrade: true, vpc: vpc.vpc, - vpcSubnets: vpc.subnetSelection ? [vpc.subnetSelection] : [], + ...vpc.subnetSelection ? {vpcSubnets: [vpc.subnetSelection]} : {}, ebs: { enabled: true, volumeSize: ragConfig.opensearchConfig.volumeSize, @@ -286,6 +316,38 @@ export class LisaRagStack extends Stack { rdsConnectionInfoPs.grantRead(lambdaRole); baseEnvironment['RDS_CONNECTION_INFO_PS_NAME'] = rdsConnectionInfoPs.parameterName; } + + // Create ingest pipeline state machines for each pipeline config + console.log('[DEBUG] Checking pipelines configuration:', { + hasPipelines: !!ragConfig.pipelines, + pipelinesLength: ragConfig.pipelines?.length || 0 + }); + + if (ragConfig.pipelines) { + ragConfig.pipelines.forEach((pipelineConfig, index) => { + console.log(`[DEBUG] Creating pipeline ${index}:`, { + pipelineConfig: JSON.stringify(pipelineConfig, null, 2) + }); + + try { + // Create a unique ID for each pipeline using repository ID and index + const pipelineId = `IngestPipeline-${ragConfig.repositoryId}-${index}`; + new IngestPipelineStateMachine(this, pipelineId, { + config, + vpc, + pipelineConfig, + rdsConfig: ragConfig.rdsConfig, + repositoryId: ragConfig.repositoryId, + type: ragConfig.type, + layers: [commonLambdaLayer, ragLambdaLayer.layer, sdkLayer] + }); + console.log(`[DEBUG] Successfully created pipeline ${index}`); + } catch (error) { + console.error(`[ERROR] Failed to create pipeline ${index}:`, error); + throw error; // Re-throw to ensure CDK deployment fails + } + }); + } } // Create Parameter Store entry with RAG repositories @@ -297,34 +359,6 @@ export class LisaRagStack extends Stack { baseEnvironment['REGISTERED_REPOSITORIES_PS_NAME'] = ragRepositoriesParam.parameterName; - // Build RAG Lambda layer - const ragLambdaLayer = new Layer(this, 'RagLayer', { - config: config, - path: RAG_LAYER_PATH, - description: 'Lambad dependencies for RAG API', - architecture: ARCHITECTURE, - autoUpgrade: 
true, - assetPath: config.lambdaLayerAssets?.ragLayerPath, - }); - - // Build SDK Layer - let sdkLayer; - if (config.lambdaLayerAssets?.sdkLayerPath) { - sdkLayer = new LayerVersion(this, 'SdkLayer', { - code: Code.fromAsset(config.lambdaLayerAssets?.sdkLayerPath), - compatibleRuntimes: [Runtime.PYTHON_3_10], - removalPolicy: config.removalPolicy, - description: 'LISA SDK common layer', - }); - } else { - sdkLayer = new PythonLayerVersion(this, 'SdkLayer', { - entry: SDK_PATH, - compatibleRuntimes: [Runtime.PYTHON_3_10], - removalPolicy: config.removalPolicy, - description: 'LISA SDK common layer', - }); - } - // Add REST API Lambdas to APIGW new RepositoryApi(this, 'RepositoryApi', { authorizer, diff --git a/lib/rag/state_machine/constants.ts b/lib/rag/state_machine/constants.ts new file mode 100644 index 00000000..706d91e0 --- /dev/null +++ b/lib/rag/state_machine/constants.ts @@ -0,0 +1,23 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +import { Duration } from 'aws-cdk-lib'; +import { WaitTime } from 'aws-cdk-lib/aws-stepfunctions'; + +export const LAMBDA_MEMORY: number = 2048; +export const LAMBDA_TIMEOUT: Duration = Duration.minutes(15); +export const OUTPUT_PATH: string = '$.Payload'; +export const POLLING_TIMEOUT: WaitTime = WaitTime.duration(Duration.seconds(60)); diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts new file mode 100644 index 00000000..fd5a04e5 --- /dev/null +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -0,0 +1,315 @@ +/** + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { + Choice, + Condition, + DefinitionBody, + Fail, + StateMachine, + Succeed, + Map, + Pass, + Chain, +} from 'aws-cdk-lib/aws-stepfunctions'; +import { Construct } from 'constructs'; +import { Duration } from 'aws-cdk-lib'; +import { BaseProps } from '../../schema'; +import { Code, Function, ILayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda'; +import { Effect, PolicyStatement } from 'aws-cdk-lib/aws-iam'; +import { LAMBDA_MEMORY, LAMBDA_TIMEOUT, OUTPUT_PATH } from './constants'; +import { Vpc } from '../../networking/vpc'; +import { LambdaInvoke } from 'aws-cdk-lib/aws-stepfunctions-tasks'; +import { Rule, Schedule, EventPattern, RuleTargetInput, EventField } from 'aws-cdk-lib/aws-events'; +import { SfnStateMachine } from 'aws-cdk-lib/aws-events-targets'; +import { RagRepositoryType } from '../../schema'; +import * as kms from 'aws-cdk-lib/aws-kms'; + +type PipelineConfig = { + chunkOverlap: number; + chunkSize: number; + embeddingModel: string; + s3Bucket: string; + s3Prefix: string; + trigger: string; + collectionName: string; +}; + +type RdsConfig = { + username: string; + dbHost?: string; + dbName: string; + dbPort: number; + passwordSecretId?: string; +}; + +type IngestPipelineStateMachineProps = BaseProps & { + vpc?: Vpc; + pipelineConfig: PipelineConfig; + rdsConfig?: RdsConfig; + repositoryId: string; + type: RagRepositoryType; + layers?: ILayerVersion[]; +}; + +/** + * State Machine for creating models. + */ +export class IngestPipelineStateMachine extends Construct { + readonly stateMachineArn: string; + + constructor (scope: Construct, id: string, props: IngestPipelineStateMachineProps) { + super(scope, id); + + const {config, vpc, pipelineConfig, rdsConfig, repositoryId, type, layers} = props; + + // Create KMS key for environment variable encryption + const kmsKey = new kms.Key(this, 'EnvironmentEncryptionKey', { + enableKeyRotation: true, + description: 'Key for encrypting Lambda environment variables' + }); + + const environment = { + CHUNK_OVERLAP: pipelineConfig.chunkOverlap.toString(), + CHUNK_SIZE: pipelineConfig.chunkSize.toString(), + EMBEDDING_MODEL: pipelineConfig.embeddingModel, + S3_BUCKET: pipelineConfig.s3Bucket, + S3_PREFIX: pipelineConfig.s3Prefix, + REPOSITORY_ID: repositoryId, + REPOSITORY_TYPE: type, + REST_API_VERSION: 'v2', + MANAGEMENT_KEY_SECRET_NAME_PS: `${config.deploymentPrefix}/managementKeySecretName`, + RDS_CONNECTION_INFO_PS_NAME: `${config.deploymentPrefix}/LisaServeRagPGVectorConnectionInfo`, + OPENSEARCH_ENDPOINT_PS_NAME: `${config.deploymentPrefix}/lisaServeRagRepositoryEndpoint`, + LISA_API_URL_PS_NAME: `${config.deploymentPrefix}/lisaServeRestApiUri`, + LOG_LEVEL: config.logLevel, + RESTAPI_SSL_CERT_ARN: config.restApiConfig.sslCertIamArn || '', + ...(rdsConfig && { + RDS_USERNAME: rdsConfig.username, + RDS_HOST: rdsConfig.dbHost || '', + RDS_DATABASE: rdsConfig.dbName, + RDS_PORT: rdsConfig.dbPort.toString(), + RDS_PASSWORD_SECRET_ID: rdsConfig.passwordSecretId || '' + }) + }; + + // Create S3 policy statement for both functions + const s3PolicyStatement = new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['s3:GetObject', 's3:ListBucket'], + resources: [ + `arn:aws:s3:::${pipelineConfig.s3Bucket}`, + `arn:aws:s3:::${pipelineConfig.s3Bucket}/*` + ] + }); + + // Create array of policy statements + const policyStatements = [s3PolicyStatement]; + + // Create IAM certificate policy if certificate ARN is provided + let certPolicyStatement; + if (config.restApiConfig.sslCertIamArn) { + certPolicyStatement = new 
PolicyStatement({ + effect: Effect.ALLOW, + actions: ['iam:GetServerCertificate'], + resources: [config.restApiConfig.sslCertIamArn] + }); + policyStatements.push(certPolicyStatement); + } + + // Function to list objects modified in last 24 hours + const listModifiedObjectsFunction = new Function(this, 'listModifiedObjectsFunc', { + runtime: Runtime.PYTHON_3_10, + handler: 'repository.state_machine.list_modified_objects.handle_list_modified_objects', + code: Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc!.vpc, + environment: environment, + environmentEncryption: kmsKey, + layers: layers, + initialPolicy: policyStatements + }); + + const listModifiedObjects = new LambdaInvoke(this, 'listModifiedObjects', { + lambdaFunction: listModifiedObjectsFunction, + outputPath: OUTPUT_PATH, + }); + + // Create a Pass state to normalize event structure for single file processing + const prepareSingleFile = new Pass(this, 'PrepareSingleFile', { + parameters: { + 'files': [{ + 'bucket': pipelineConfig.s3Bucket, + 'key.$': '$.detail.object.key' + }] + } + }); + + // Create the ingest documents function with S3 permissions + const pipelineIngestDocumentsFunction = new Function(this, 'pipelineIngestDocumentsMapFunc', { + runtime: Runtime.PYTHON_3_10, + handler: 'repository.pipeline_ingest_documents.handle_pipeline_ingest_documents', + code: Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc!.vpc, + environment: environment, + environmentEncryption: kmsKey, + layers: layers, + initialPolicy: [ + ...policyStatements, // Include all base policies including certificate policy + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['ssm:GetParameter'], + resources: [ + `arn:aws:ssm:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:parameter${config.deploymentPrefix}/LisaServeRagPGVectorConnectionInfo`, + `arn:aws:ssm:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:parameter${config.deploymentPrefix}/lisaServeRagRepositoryEndpoint`, + `arn:aws:ssm:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:parameter${config.deploymentPrefix}/lisaServeRestApiUri`, + `arn:aws:ssm:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:parameter${config.deploymentPrefix}/managementKeySecretName` + ] + }), + new PolicyStatement({ + effect: Effect.ALLOW, + actions: ['secretsmanager:GetSecretValue'], + resources: [ + `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}-lisa-management-key*`, + `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret*` + ] + }) + ] + }); + + const pipelineIngestDocumentsMap = new LambdaInvoke(this, 'pipelineIngestDocumentsMap', { + lambdaFunction: pipelineIngestDocumentsFunction, + outputPath: OUTPUT_PATH, + }); + + const handleFailureState = new LambdaInvoke(this, 'HandleFailure', { + lambdaFunction: new Function(this, 'HandleFailureFunc', { + runtime: Runtime.PYTHON_3_10, + handler: 'repository.state_machine.create_model.handle_failure', + code: Code.fromAsset('./lambda'), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + vpc: vpc!.vpc, + environment: environment, + environmentEncryption: kmsKey, + layers: layers, + }), + outputPath: OUTPUT_PATH, + }); + + const successState = new Succeed(this, 'CreateSuccess'); + const failState = new Fail(this, 'CreateFailed'); 
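+
+        // Flow (wired up below): DetermineTriggerType routes 'daily' events
+        // through listModifiedObjects to enumerate recently modified objects,
+        // while everything else goes through PrepareSingleFile; both paths
+        // fan out into the ProcessFiles Map state before CreateSuccess.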
+ + // Map state for distributed processing with rate limiting + const processFiles = new Map(this, 'ProcessFiles', { + maxConcurrency: 5, // Reduced from 10 for better rate limiting + itemsPath: '$.files', + }); + processFiles.iterator(pipelineIngestDocumentsMap); + + // Choice state to determine trigger type + const triggerChoice = new Choice(this, 'DetermineTriggerType') + .when(Condition.stringEquals('$.detail.trigger', 'daily'), listModifiedObjects) + .otherwise(prepareSingleFile); + + // Build the chain + const definition = Chain + .start(triggerChoice); + + listModifiedObjects.next(processFiles); + prepareSingleFile.next(processFiles); + processFiles.next(successState); + handleFailureState.next(failState); + + const stateMachine = new StateMachine(this, 'IngestPipeline', { + definitionBody: DefinitionBody.fromChainable(definition), + timeout: Duration.hours(2), + }); + + // Add EventBridge Rules based on pipeline configuration + if (pipelineConfig.trigger === 'daily') { + // Create daily cron trigger with input template + new Rule(this, 'DailyIngestRule', { + schedule: Schedule.cron({ + minute: '0', + hour: '0' + }), + targets: [new SfnStateMachine(stateMachine, { + input: RuleTargetInput.fromObject({ + version: '0', + id: EventField.eventId, + 'detail-type': 'Scheduled Event', + source: 'aws.events', + time: EventField.time, + region: EventField.region, + detail: { + bucket: pipelineConfig.s3Bucket, + prefix: pipelineConfig.s3Prefix, + trigger: 'daily' + } + }) + })] + }); + } else if (pipelineConfig.trigger === 'event') { + // Create S3 event trigger with complete event pattern and transform input + const detail: any = { + bucket: { + name: [pipelineConfig.s3Bucket] + } + }; + + // Add prefix filter if specified and not root + if (pipelineConfig.s3Prefix && pipelineConfig.s3Prefix !== '/') { + detail.object = { + key: [{ + prefix: pipelineConfig.s3Prefix + }] + }; + } + + const eventPattern: EventPattern = { + source: ['aws.s3'], + detailType: ['Object Created', 'Object Modified'], + detail + }; + + new Rule(this, 'S3EventIngestRule', { + eventPattern, + targets: [new SfnStateMachine(stateMachine, { + input: RuleTargetInput.fromObject({ + 'detail-type': EventField.detailType, + source: EventField.source, + time: EventField.time, + region: EventField.region, + detail: { + bucket: pipelineConfig.s3Bucket, + prefix: pipelineConfig.s3Prefix, + object: { + key: EventField.fromPath('$.detail.object.key') + }, + trigger: 'event' + } + }) + })] + }); + } + + this.stateMachineArn = stateMachine.stateMachineArn; + } +} diff --git a/lib/schema.ts b/lib/schema.ts index bd34ccb0..fd360714 100644 --- a/lib/schema.ts +++ b/lib/schema.ts @@ -594,22 +594,29 @@ const OpenSearchExistingClusterConfig = z.object({ /** * Configuration schema for RAG repository. Defines settings for OpenSearch. 
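+ *
+ * The optional pipelines array configures automated ingest pipelines
+ * (chunk size/overlap, embedding model, source bucket and prefix, and a
+ * 'daily' or 'event' trigger).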
*/ -const RagRepositoryConfigSchema = z - .object({ - repositoryId: z.string(), - type: z.nativeEnum(RagRepositoryType), - opensearchConfig: z.union([OpenSearchExistingClusterConfig, OpenSearchNewClusterConfig]).optional(), - rdsConfig: RdsInstanceConfig.optional(), - }) - .refine((input) => { - if ( - (input.type === RagRepositoryType.OPENSEARCH && input.opensearchConfig === undefined) || - (input.type === RagRepositoryType.PGVECTOR && input.rdsConfig === undefined) - ) { - return false; - } - return true; - }); +const RagRepositoryConfigSchema = z.object({ + repositoryId: z.string(), + type: z.nativeEnum(RagRepositoryType), + opensearchConfig: z.union([OpenSearchExistingClusterConfig, OpenSearchNewClusterConfig]).optional(), + rdsConfig: RdsInstanceConfig.optional(), + pipelines: z.array(z.object({ + chunkOverlap: z.number(), + chunkSize: z.number(), + embeddingModel: z.string(), + s3Bucket: z.string(), + s3Prefix: z.string(), + trigger: z.union([z.literal('daily'), z.literal('event')]), + collectionName: z.string() + })).optional() +}).refine((input) => { + if ( + (input.type === RagRepositoryType.OPENSEARCH && input.opensearchConfig === undefined) || + (input.type === RagRepositoryType.PGVECTOR && input.rdsConfig === undefined) + ) { + return false; + } + return true; +}); /** * Configuration schema for RAG file processing. Determines the chunk size and chunk overlap when processing documents. diff --git a/lib/serve/index.ts b/lib/serve/index.ts index 29867f4d..4bdeded8 100644 --- a/lib/serve/index.ts +++ b/lib/serve/index.ts @@ -82,9 +82,10 @@ export class LisaServeApplicationStack extends Stack { vpc: vpc, }); + // Use a stable name for the management key secret const managementKeySecret = new Secret(this, createCdkId([id, 'managementKeySecret']), { - secretName: `lisa_management_key_secret-${Date.now()}`, // pragma: allowlist secret` - description: 'This is a secret created with AWS CDK', + secretName: `${config.deploymentName}-lisa-management-key`, // Use stable name based on deployment + description: 'LISA management key secret', generateSecretString: { excludePunctuation: true, passwordLength: 16 diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index 8a2953bf..f64fd690 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,6 +1,6 @@ # Use an argument for the base image ARG BASE_IMAGE -FROM ${BASE_IMAGE} +FROM --platform=linux/amd64 ${BASE_IMAGE} # Copy LiteLLM config directly out of the LISA config.yaml file ARG LITELLM_CONFIG diff --git a/package-lock.json b/package-lock.json index 0664a897..06900c0c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "js-yaml": "^4.1.0", "lodash": "^4.17.21", "source-map-support": "^0.5.21", + "util": "^0.12.5", "zod": "^3.22.3" }, "bin": { @@ -4234,7 +4235,6 @@ "version": "1.0.7", "resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.7.tgz", "integrity": "sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==", - "dev": true, "dependencies": { "possible-typed-array-names": "^1.0.0" }, @@ -4899,7 +4899,6 @@ "version": "1.0.7", "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", - "dev": true, "dependencies": { "es-define-property": "^1.0.0", "es-errors": "^1.3.0", @@ -5436,7 +5435,6 @@ "version": "1.1.4", "resolved": 
"https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", - "dev": true, "dependencies": { "es-define-property": "^1.0.0", "es-errors": "^1.3.0", @@ -5744,7 +5742,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", - "dev": true, "dependencies": { "get-intrinsic": "^1.2.4" }, @@ -5756,7 +5753,6 @@ "version": "1.3.0", "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", - "dev": true, "engines": { "node": ">= 0.4" } @@ -6526,7 +6522,6 @@ "version": "0.3.3", "resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.3.tgz", "integrity": "sha512-jqYfLp7mo9vIyQf8ykW2v7A+2N4QjeCeI5+Dz9XraiO1ign81wjiH7Fb9vSOWvQfNtmSa4H2RoQTrrXivdUZmw==", - "dev": true, "dependencies": { "is-callable": "^1.1.3" } @@ -6556,7 +6551,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", - "dev": true, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -6623,7 +6617,6 @@ "version": "1.2.4", "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", - "dev": true, "dependencies": { "es-errors": "^1.3.0", "function-bind": "^1.1.2", @@ -6822,7 +6815,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", - "dev": true, "dependencies": { "get-intrinsic": "^1.1.3" }, @@ -6864,7 +6856,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", - "dev": true, "dependencies": { "es-define-property": "^1.0.0" }, @@ -6876,7 +6867,6 @@ "version": "1.0.3", "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", - "dev": true, "engines": { "node": ">= 0.4" }, @@ -6888,7 +6878,6 @@ "version": "1.0.3", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", - "dev": true, "engines": { "node": ">= 0.4" }, @@ -6900,7 +6889,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", - "dev": true, "dependencies": { "has-symbols": "^1.0.3" }, @@ -6915,7 +6903,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", - "dev": true, "dependencies": { "function-bind": "^1.1.2" }, @@ -7041,8 +7028,7 @@ 
"node_modules/inherits": { "version": "2.0.4", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" }, "node_modules/ini": { "version": "1.3.8", @@ -7064,6 +7050,21 @@ "node": ">= 0.4" } }, + "node_modules/is-arguments": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-arguments/-/is-arguments-1.1.1.tgz", + "integrity": "sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA==", + "dependencies": { + "call-bind": "^1.0.2", + "has-tostringtag": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/is-array-buffer": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.4.tgz", @@ -7118,7 +7119,6 @@ "version": "1.2.7", "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.7.tgz", "integrity": "sha512-1BC0BVFhS/p0qtw6enp8e+8OD0UrK0oFLztSjNzhcKA3WDuJxxAPXzPuPtKkjEY9UUoEWlX/8fgKeu2S8i9JTA==", - "dev": true, "engines": { "node": ">= 0.4" }, @@ -7202,6 +7202,20 @@ "node": ">=6" } }, + "node_modules/is-generator-function": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/is-generator-function/-/is-generator-function-1.0.10.tgz", + "integrity": "sha512-jsEjy9l3yiXEQ+PsXdmBwEPcOxaXWLspKdplFUVI9vq1iZgIekeC0L167qeu86czQaxed3q/Uzuw0swL0irL8A==", + "dependencies": { + "has-tostringtag": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -7336,7 +7350,6 @@ "version": "1.1.13", "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.13.tgz", "integrity": "sha512-uZ25/bUAlUY5fR4OKT4rZQEBrzQWYV9ZJYGGsUmEJ6thodVJ1HX64ePQ6Z0qPWP+m+Uq6e9UugrE38jeYsDSMw==", - "dev": true, "dependencies": { "which-typed-array": "^1.1.14" }, @@ -10271,7 +10284,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.0.0.tgz", "integrity": "sha512-d7Uw+eZoloe0EHDIYoe+bQ5WXnGMOpmiZFTuMWCwpjzzkL2nTjcKiAk4hh8TjnGye2TwWOk3UXucZ+3rbmBa8Q==", - "dev": true, "engines": { "node": ">= 0.4" } @@ -10667,7 +10679,6 @@ "version": "1.2.2", "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", - "dev": true, "dependencies": { "define-data-property": "^1.1.4", "es-errors": "^1.3.0", @@ -11412,6 +11423,18 @@ "punycode": "^2.1.0" } }, + "node_modules/util": { + "version": "0.12.5", + "resolved": "https://registry.npmjs.org/util/-/util-0.12.5.tgz", + "integrity": "sha512-kZf/K6hEIrWHI6XqOFUiiMa+79wE/D8Q+NCNAWclkyg3b4d2k7s0QGepNjiABc+aR3N1PAyHL7p6UcLY6LmrnA==", + "dependencies": { + "inherits": "^2.0.3", + "is-arguments": "^1.0.4", + "is-generator-function": "^1.0.7", + "is-typed-array": "^1.1.3", + "which-typed-array": "^1.1.2" + } + }, "node_modules/uuid": { "version": "9.0.1", "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.1.tgz", @@ -11489,7 +11512,6 @@ "version": "1.1.15", "resolved": 
"https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.15.tgz", "integrity": "sha512-oV0jmFtUky6CXfkqehVvBP/LSWJ2sy4vWMioiENyJLePrBO/yKyV9OyJySfAKosh+RYkIl5zJCNZ8/4JncrpdA==", - "dev": true, "dependencies": { "available-typed-arrays": "^1.0.7", "call-bind": "^1.0.7", diff --git a/package.json b/package.json index 41e648f2..3b388110 100644 --- a/package.json +++ b/package.json @@ -45,6 +45,7 @@ "js-yaml": "^4.1.0", "lodash": "^4.17.21", "source-map-support": "^0.5.21", + "util": "^0.12.5", "zod": "^3.22.3" }, "lint-staged": { @@ -65,5 +66,6 @@ "test": "test" }, "author": "", - "license": "Apache-2.0" + "license": "Apache-2.0", + "keywords": [] } From 52a0fb951ca9b286de8855e77606938a47f49312 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 13:37:13 -0700 Subject: [PATCH 02/15] update requirements --- lib/rag/layer/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/rag/layer/requirements.txt b/lib/rag/layer/requirements.txt index d35e937e..88694e98 100644 --- a/lib/rag/layer/requirements.txt +++ b/lib/rag/layer/requirements.txt @@ -1,8 +1,8 @@ boto3>=1.34.131 botocore>=1.34.131 -langchain==0.3.0 -langchain-community==0.3.0 -langchain-openai==0.2.4 +langchain==0.2.16 +langchain-community==0.2.17 +langchain-openai==0.1.25 opensearch-py==2.6.0 pgvector==0.2.5 psycopg2-binary==2.9.9 From 15f8ba127827fa3fe71b0ded38a9c7ffc816f12d Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 14:02:14 -0700 Subject: [PATCH 03/15] update requirements --- lib/rag/state_machine/ingest-pipeline.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index fd5a04e5..1d7a38a6 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -37,6 +37,8 @@ import { Rule, Schedule, EventPattern, RuleTargetInput, EventField } from 'aws-c import { SfnStateMachine } from 'aws-cdk-lib/aws-events-targets'; import { RagRepositoryType } from '../../schema'; import * as kms from 'aws-cdk-lib/aws-kms'; +import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; +import { StringParameter } from 'aws-cdk-lib/aws-ssm'; type PipelineConfig = { chunkOverlap: number; @@ -159,6 +161,7 @@ export class IngestPipelineStateMachine extends Construct { } }); + const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`); // Create the ingest documents function with S3 permissions const pipelineIngestDocumentsFunction = new Function(this, 'pipelineIngestDocumentsMapFunc', { runtime: Runtime.PYTHON_3_10, @@ -186,7 +189,7 @@ export class IngestPipelineStateMachine extends Construct { effect: Effect.ALLOW, actions: ['secretsmanager:GetSecretValue'], resources: [ - `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}-lisa-management-key*`, + `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret*` ] }) From f78c4aef270f8f69d43aee504dce2e4b85ef8c31 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 15:11:27 -0700 Subject: [PATCH 04/15] update requirements --- lib/rag/state_machine/ingest-pipeline.ts | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index 1d7a38a6..8276c1d5 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -190,7 +190,7 @@ export class IngestPipelineStateMachine extends Construct { actions: ['secretsmanager:GetSecretValue'], resources: [ `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, - `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret*` + `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}-??????` ] }) ] From 57538cdbb0c8d076ad862482f9d20c04cb4aea32 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 16:15:36 -0700 Subject: [PATCH 05/15] update requirements --- lib/rag/state_machine/ingest-pipeline.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index 8276c1d5..e81fee39 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -190,7 +190,7 @@ export class IngestPipelineStateMachine extends Construct { actions: ['secretsmanager:GetSecretValue'], resources: [ `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, - `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}-??????` + `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret--??????` ] }) ] From 667101ed24d27c4a82527dc630510f722ec59f29 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 16:38:30 -0700 Subject: [PATCH 06/15] secrets manager --- lib/rag/state_machine/ingest-pipeline.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index e81fee39..d4a392b2 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -39,6 +39,7 @@ import { RagRepositoryType } from '../../schema'; import * as kms from 'aws-cdk-lib/aws-kms'; import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; +import { createCdkId } from '../../core/utils'; type PipelineConfig = { chunkOverlap: number; @@ -190,7 +191,7 @@ export class IngestPipelineStateMachine extends Construct { actions: ['secretsmanager:GetSecretValue'], resources: [ `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, - `arn:aws:secretsmanager:${process.env.CDK_DEFAULT_REGION}:${process.env.CDK_DEFAULT_ACCOUNT}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret--??????` + `${Secret.fromSecretNameV2(this, createCdkId([config.deploymentName, 'RagRDSPwdSecret']), rdsConfig?.passwordSecretId ?? 
'').secretArn}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret--??????` ] }) ] From 4fa4512bab9d5b1f9348e18f5145c88497f6af86 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 16:38:59 -0700 Subject: [PATCH 07/15] secrets manager --- lib/rag/state_machine/ingest-pipeline.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index d4a392b2..ba413696 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -191,7 +191,7 @@ export class IngestPipelineStateMachine extends Construct { actions: ['secretsmanager:GetSecretValue'], resources: [ `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, - `${Secret.fromSecretNameV2(this, createCdkId([config.deploymentName, 'RagRDSPwdSecret']), rdsConfig?.passwordSecretId ?? '').secretArn}:secret:${config.deploymentName}LisaRAGPGVectorDBSecret--??????` + `${Secret.fromSecretNameV2(this, createCdkId([config.deploymentName, 'RagRDSPwdSecret']), rdsConfig?.passwordSecretId ?? '').secretArn}-??????` ] }) ] From 8e094273f0790afe595c867bfd35daf8c90fc1f0 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 6 Nov 2024 16:58:00 -0700 Subject: [PATCH 08/15] secrets manager --- lib/rag/state_machine/ingest-pipeline.ts | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index ba413696..78009700 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -189,10 +189,7 @@ export class IngestPipelineStateMachine extends Construct { new PolicyStatement({ effect: Effect.ALLOW, actions: ['secretsmanager:GetSecretValue'], - resources: [ - `${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`, // question marks required to resolve the ARN correctly, - `${Secret.fromSecretNameV2(this, createCdkId([config.deploymentName, 'RagRDSPwdSecret']), rdsConfig?.passwordSecretId ?? '').secretArn}-??????` - ] + resources: ['*'] }) ] }); From 38f4ee6df6ac9c4816df17c4a1089f9783df5cc6 Mon Sep 17 00:00:00 2001 From: Dave Horne <150732144+djhorne-amazon@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:34:25 -0500 Subject: [PATCH 09/15] Feature/rag pipeline patch --- Makefile | 19 ++++++----- README.md | 3 +- example_config.yaml | 4 --- .../repository/pipeline_ingest_documents.py | 31 ++++++++++-------- lib/rag/state_machine/ingest-pipeline.ts | 32 +++++++++---------- package.json | 2 +- test/cdk/mocks/config.yaml | 4 --- test/cdk/stacks/chat.test.ts | 6 ++-- 8 files changed, 47 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index c5129a00..6928b92c 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,8 @@ PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) HEADLESS = false - +DOCKER_CMD := $(CDK_DOCKER) +DOCKER_CMD ?= docker # Arguments defined through command line or config.yaml # PROFILE (optional argument) @@ -150,13 +151,11 @@ installTypeScriptRequirements: ## Make sure Docker is running dockerCheck: - @cmd_output=$$(docker ps); \ - if \ - [ $$? != 0 ]; \ - then \ - echo $$cmd_output; \ - exit 1; \ - fi; \ + @cmd_output=$$(pgrep -f "${DOCKER_CMD}"); \ + if [ $$? != 0 ]; then \ + echo "Process $(DOCKER_CMD) is not running. 
Exiting..."; \ + exit 1; \ + fi \ ## Check if models are uploaded modelCheck: @@ -229,11 +228,11 @@ cleanMisc: dockerLogin: dockerCheck ifdef PROFILE @$(foreach ACCOUNT,$(ACCOUNT_NUMBERS_ECR), \ - aws ecr get-login-password --region ${REGION} --profile ${PROFILE} | docker login --username AWS --password-stdin $(ACCOUNT).dkr.ecr.${REGION}.${URL_SUFFIX} >/dev/null 2>&1; \ + aws ecr get-login-password --region ${REGION} --profile ${PROFILE} | ${DOCKER_CMD} login --username AWS --password-stdin ${ACCOUNT}.dkr.ecr.${REGION}.${URL_SUFFIX} >/dev/null 2>&1; \ ) else @$(foreach ACCOUNT,$(ACCOUNT_NUMBERS_ECR), \ - aws ecr get-login-password --region ${REGION} | docker login --username AWS --password-stdin $(ACCOUNT).dkr.ecr.${REGION}.${URL_SUFFIX} >/dev/null 2>&1; \ + aws ecr get-login-password --region ${REGION} | ${DOCKER_CMD} login --username AWS --password-stdin ${ACCOUNT}.dkr.ecr.${REGION}.${URL_SUFFIX} >/dev/null 2>&1; \ ) endif diff --git a/README.md b/README.md index 78b4ee0f..2d7cadab 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ Before beginning, ensure you have: 3. Familiarity with AWS Cloud Development Kit (CDK) and infrastructure-as-code principles 4. Python 3.9 or later 5. Node.js 14 or later -6. Docker installed and running +6. Docker/Finch installed and running 7. Sufficient disk space for model downloads and conversions If you're new to CDK, review the [AWS CDK Documentation](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html) and consult with your AWS support team. @@ -235,6 +235,7 @@ Set the following environment variables: export PROFILE=my-aws-profile # Optional, can be left blank export DEPLOYMENT_NAME=my-deployment export ENV=dev # Options: dev, test, or prod +export CDK_DOCKER=finch # Optional, only required if not using docker as container engine ``` --- diff --git a/example_config.yaml b/example_config.yaml index e481b094..1275ab32 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -21,10 +21,6 @@ dev: # rolePrefix: CustomPrefix # policyPrefix: CustomPrefix # instanceProfilePrefix: CustomPrefix - # systemBanner: - # text: 'LISA System' - # backgroundColor: orange - # fontColor: black # vpcId: vpc-0123456789abcdef, # subnetIds: [subnet-fedcba9876543210, subnet-0987654321fedcba], s3BucketModels: hf-models-gaiic diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 5f2c2eb6..9dcdcb22 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -22,11 +22,10 @@ from utilities.common_functions import retry_config from utilities.file_processing import process_record from utilities.validation import ( - safe_error_response, + ValidationError, validate_chunk_params, validate_model_name, validate_repository_type, - ValidationError, ) from utilities.vector_store import get_vector_store_client @@ -70,6 +69,9 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic Returns: Dictionary with status code and response body + + Raises: + Exception: For any error to signal failure to Step Functions """ try: # Get document location from event @@ -151,19 +153,20 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dic logger.info(f"Successfully processed {len(all_ids)} chunks from {s3_key} for repository {repository_id}") return { - "statusCode": 200, - "body": { - "message": f"Successfully processed document {s3_key}", - "repository_id": repository_id, - "repository_type": repository_type, - 
"chunks_processed": len(all_ids), - "document_ids": all_ids, - }, + "message": f"Successfully processed document {s3_key}", + "repository_id": repository_id, + "repository_type": repository_type, + "chunks_processed": len(all_ids), + "document_ids": all_ids, } except ValidationError as e: - # Return 400 for validation errors - return safe_error_response(e) + # For validation errors, raise with clear message + error_msg = f"Validation error: {str(e)}" + logger.error(error_msg) + raise Exception(error_msg) except Exception as e: - # Return 500 for other errors - return safe_error_response(e) + # For all other errors, log and re-raise to signal failure + error_msg = f"Failed to process document: {str(e)}" + logger.error(error_msg, exc_info=True) + raise Exception(error_msg) diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index 78009700..de0abec1 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -196,34 +196,33 @@ export class IngestPipelineStateMachine extends Construct { const pipelineIngestDocumentsMap = new LambdaInvoke(this, 'pipelineIngestDocumentsMap', { lambdaFunction: pipelineIngestDocumentsFunction, - outputPath: OUTPUT_PATH, + retryOnServiceExceptions: true, // Enable retries for service exceptions + resultPath: '$.taskResult' // Store the entire result }); - const handleFailureState = new LambdaInvoke(this, 'HandleFailure', { - lambdaFunction: new Function(this, 'HandleFailureFunc', { - runtime: Runtime.PYTHON_3_10, - handler: 'repository.state_machine.create_model.handle_failure', - code: Code.fromAsset('./lambda'), - timeout: LAMBDA_TIMEOUT, - memorySize: LAMBDA_MEMORY, - vpc: vpc!.vpc, - environment: environment, - environmentEncryption: kmsKey, - layers: layers, - }), - outputPath: OUTPUT_PATH, + const failState = new Fail(this, 'CreateFailed', { + cause: 'Pipeline execution failed', + error: 'States.TaskFailed' }); const successState = new Succeed(this, 'CreateSuccess'); - const failState = new Fail(this, 'CreateFailed'); // Map state for distributed processing with rate limiting const processFiles = new Map(this, 'ProcessFiles', { - maxConcurrency: 5, // Reduced from 10 for better rate limiting + maxConcurrency: 5, itemsPath: '$.files', + resultPath: '$.mapResults' // Store map results in mapResults field }); + + // Configure the iterator without error handling (will be handled at Map level) processFiles.iterator(pipelineIngestDocumentsMap); + // Add error handling at Map level + processFiles.addCatch(failState, { + errors: ['States.ALL'], + resultPath: '$.error' + }); + // Choice state to determine trigger type const triggerChoice = new Choice(this, 'DetermineTriggerType') .when(Condition.stringEquals('$.detail.trigger', 'daily'), listModifiedObjects) @@ -236,7 +235,6 @@ export class IngestPipelineStateMachine extends Construct { listModifiedObjects.next(processFiles); prepareSingleFile.next(processFiles); processFiles.next(successState); - handleFailureState.next(failState); const stateMachine = new StateMachine(this, 'IngestPipeline', { definitionBody: DefinitionBody.fromChainable(definition), diff --git a/package.json b/package.json index 3b388110..f897a9b4 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,7 @@ "migrate-properties": "node ./scripts/migrate-properties.mjs", "postinstall": "(cd lib/user-interface/react && npm install) && (cd lib/docs && npm install)", "postbuild": "(cd lib/user-interface/react && npm build) && (cd lib/docs && npm build)" -}, + }, 
"devDependencies": { "@aws-cdk/aws-lambda-python-alpha": "2.125.0-alpha.0", "@aws-sdk/client-iam": "^3.490.0", diff --git a/test/cdk/mocks/config.yaml b/test/cdk/mocks/config.yaml index 059f97f7..e0991c48 100644 --- a/test/cdk/mocks/config.yaml +++ b/test/cdk/mocks/config.yaml @@ -21,10 +21,6 @@ dev: # rolePrefix: CustomPrefix # policyPrefix: CustomPrefix # instanceProfilePrefix: CustomPrefix - # systemBanner: - # text: 'LISA System' - # backgroundColor: orange - # fontColor: black s3BucketModels: hf-models-gaiic # aws partition mountS3 package location mountS3DebUrl: https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb diff --git a/test/cdk/stacks/chat.test.ts b/test/cdk/stacks/chat.test.ts index ac766ff2..304adcf0 100644 --- a/test/cdk/stacks/chat.test.ts +++ b/test/cdk/stacks/chat.test.ts @@ -105,12 +105,12 @@ describe.each(regions)('Chat Nag Pack Tests | Region Test: %s', (awsRegion) => { //TODO Update expect values to remediate CDK NAG findings and remove debug test('AwsSolutions CDK NAG Warnings', () => { const warnings = Annotations.fromStack(stack).findWarning('*', Match.stringLikeRegexp('AwsSolutions-.*')); - expect(warnings.length).toBe(1); + expect(warnings.length).toBe(2); }); test('AwsSolutions CDK NAG Errors', () => { const errors = Annotations.fromStack(stack).findError('*', Match.stringLikeRegexp('AwsSolutions-.*')); - expect(errors.length).toBe(17); + expect(errors.length).toBe(28); }); test('NIST800.53r5 CDK NAG Warnings', () => { @@ -120,6 +120,6 @@ describe.each(regions)('Chat Nag Pack Tests | Region Test: %s', (awsRegion) => { test('NIST800.53r5 CDK NAG Errors', () => { const errors = Annotations.fromStack(stack).findError('*', Match.stringLikeRegexp('NIST.*')); - expect(errors.length).toBe(4); + expect(errors.length).toBe(11); }); }); From 05bc60fc61c93d7438c17650b854b8244e702324 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Thu, 7 Nov 2024 10:40:41 -0700 Subject: [PATCH 10/15] Update Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 95e25766..09cedb20 100644 --- a/Makefile +++ b/Makefile @@ -151,7 +151,7 @@ installTypeScriptRequirements: ## Make sure Docker is running dockerCheck: -@cmd_output=$$($(DOCKER_CMD) ps); + @cmd_output=$$($(DOCKER_CMD) ps); \ if [ $$? != 0 ]; then \ echo "Process $(DOCKER_CMD) is not running. Exiting..."; \ exit 1; \ From e86603bc2615a950a187b3b1c2aa5aa4c3ad23db Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Thu, 7 Nov 2024 10:41:21 -0700 Subject: [PATCH 11/15] Update Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 09cedb20..92ec4518 100644 --- a/Makefile +++ b/Makefile @@ -151,7 +151,7 @@ installTypeScriptRequirements: ## Make sure Docker is running dockerCheck: - @cmd_output=$$($(DOCKER_CMD) ps); \ + @cmd_output=$$($(DOCKER_CMD) ps); \ if [ $$? != 0 ]; then \ echo "Process $(DOCKER_CMD) is not running. 
Exiting..."; \ exit 1; \ From 1bebdf272dd778088fea781a0829c4f5ad46712b Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Tue, 12 Nov 2024 12:24:09 -0700 Subject: [PATCH 12/15] formatting changes --- lambda/__init__.py | 13 +++++++++++++ lambda/repository/pipeline_ingest_documents.py | 11 +++-------- lambda/repository/state_machine/__init__.py | 13 +++++++++++++ .../state_machine/list_modified_objects.py | 2 +- lib/rag/state_machine/ingest-pipeline.ts | 4 ---- 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/lambda/__init__.py b/lambda/__init__.py index e69de29b..4139ae4d 100644 --- a/lambda/__init__.py +++ b/lambda/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py index 9dcdcb22..816c493d 100644 --- a/lambda/repository/pipeline_ingest_documents.py +++ b/lambda/repository/pipeline_ingest_documents.py @@ -21,12 +21,7 @@ import boto3 from utilities.common_functions import retry_config from utilities.file_processing import process_record -from utilities.validation import ( - ValidationError, - validate_chunk_params, - validate_model_name, - validate_repository_type, -) +from utilities.validation import validate_chunk_params, validate_model_name, validate_repository_type, ValidationError from utilities.vector_store import get_vector_store_client from .lambda_functions import _get_embeddings_pipeline @@ -49,8 +44,8 @@ def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) """ batches = [] for i in range(0, len(texts), batch_size): - text_batch = texts[i: i + batch_size] - metadata_batch = metadatas[i: i + batch_size] + text_batch = texts[i : i + batch_size] + metadata_batch = metadatas[i : i + batch_size] batches.append((text_batch, metadata_batch)) return batches diff --git a/lambda/repository/state_machine/__init__.py b/lambda/repository/state_machine/__init__.py index e69de29b..4139ae4d 100644 --- a/lambda/repository/state_machine/__init__.py +++ b/lambda/repository/state_machine/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/lambda/repository/state_machine/list_modified_objects.py b/lambda/repository/state_machine/list_modified_objects.py index 81228ca7..17ad6003 100644 --- a/lambda/repository/state_machine/list_modified_objects.py +++ b/lambda/repository/state_machine/list_modified_objects.py @@ -74,7 +74,7 @@ def validate_bucket_prefix(bucket: str, prefix: str) -> bool: return True -def handle_list_modified_objects(event: Dict[str, Any], context: Any) -> Dict[str, Any]: +def handle_list_modified_objects(event: Dict[str, Any], context: Any) -> Dict[str, Any] | Any: """ Lists all objects in the specified S3 bucket and prefix that were modified in the last 24 hours. diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts index de0abec1..6f231cd7 100644 --- a/lib/rag/state_machine/ingest-pipeline.ts +++ b/lib/rag/state_machine/ingest-pipeline.ts @@ -37,9 +37,6 @@ import { Rule, Schedule, EventPattern, RuleTargetInput, EventField } from 'aws-c import { SfnStateMachine } from 'aws-cdk-lib/aws-events-targets'; import { RagRepositoryType } from '../../schema'; import * as kms from 'aws-cdk-lib/aws-kms'; -import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; -import { StringParameter } from 'aws-cdk-lib/aws-ssm'; -import { createCdkId } from '../../core/utils'; type PipelineConfig = { chunkOverlap: number; @@ -162,7 +159,6 @@ export class IngestPipelineStateMachine extends Construct { } }); - const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`); // Create the ingest documents function with S3 permissions const pipelineIngestDocumentsFunction = new Function(this, 'pipelineIngestDocumentsMapFunc', { runtime: Runtime.PYTHON_3_10, From 71f75d675ef1d1dd32bb200de5925b5302dbf844 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Tue, 12 Nov 2024 13:04:33 -0700 Subject: [PATCH 13/15] fix formatting issues --- .pre-commit-config.yaml | 8 ++---- lambda/models/domain_objects.py | 12 ++++---- lambda/models/lambda_functions.py | 28 +++++++++---------- .../src/api/endpoints/v1/embeddings.py | 2 +- .../src/api/endpoints/v1/generation.py | 8 +++--- .../rest-api/src/api/endpoints/v1/models.py | 8 +++--- .../api/endpoints/v2/litellm_passthrough.py | 4 +-- lib/serve/rest-api/src/api/routes.py | 2 +- lib/serve/rest-api/src/handlers/generation.py | 4 +-- .../src/lisa_serve/ecs/textgen/tgi.py | 22 ++++++++------- lib/serve/rest-api/src/main.py | 2 +- .../src/utils/generate_litellm_config.py | 6 ++-- lisa-sdk/lisapy/main.py | 2 +- lisa-sdk/tests/test_client.py | 8 +++--- pyproject.toml | 2 +- test/cdk/stacks/iam-stack.test.ts | 2 +- 16 files changed, 58 insertions(+), 62 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75a65def..f940bb11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,7 @@ repos: name: isort (python) - repo: https://github.com/ambv/black - rev: '23.10.1' + rev: '24.10.0' hooks: - id: black @@ -66,21 +66,19 @@ repos: args: [--exit-non-zero-on-fix] - repo: https://github.com/pycqa/flake8 - rev: '6.1.0' + rev: '7.1.1' hooks: - id: flake8 additional_dependencies: - - flake8-docstrings - flake8-broken-line - flake8-bugbear - flake8-comprehensions - flake8-debugger - flake8-string-format args: - - --docstring-convention=numpy - --max-line-length=120 - --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends - - --ignore=B008 # Ignore error for function calls in argument defaults + - --ignore=B008,E203 # Ignore error for 
function calls in argument defaults exclude: ^(__init__.py$|.*\/__init__.py$) diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index dfa7143e..8aa8c048 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -98,7 +98,7 @@ class AutoScalingConfig(BaseModel): defaultInstanceWarmup: PositiveInt metricConfig: MetricConfig - @model_validator(mode="after") # type: ignore + @model_validator(mode="after") def validate_auto_scaling_config(self) -> Self: """Validate autoScalingConfig values.""" if self.minCapacity > self.maxCapacity: @@ -115,7 +115,7 @@ class AutoScalingInstanceConfig(BaseModel): maxCapacity: Optional[PositiveInt] = None desiredCapacity: Optional[PositiveInt] = None - @model_validator(mode="after") # type: ignore + @model_validator(mode="after") def validate_auto_scaling_instance_config(self) -> Self: """Validate autoScalingInstanceConfig values.""" config_fields = [self.minCapacity, self.maxCapacity, self.desiredCapacity] @@ -155,7 +155,7 @@ class ContainerConfig(BaseModel): healthCheckConfig: ContainerHealthCheckConfig environment: Optional[Dict[str, str]] = {} - @field_validator("environment") # type: ignore + @field_validator("environment") @classmethod def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]: """Validate that all keys in Dict are not empty.""" @@ -201,7 +201,7 @@ class CreateModelRequest(BaseModel): modelUrl: Optional[str] = None streaming: Optional[bool] = False - @model_validator(mode="after") # type: ignore + @model_validator(mode="after") def validate_create_model_request(self) -> Self: """Validate whole request object.""" # Validate that an embedding model cannot be set as streaming-enabled @@ -252,7 +252,7 @@ class UpdateModelRequest(BaseModel): modelType: Optional[ModelType] = None streaming: Optional[bool] = None - @model_validator(mode="after") # type: ignore + @model_validator(mode="after") def validate_update_model_request(self) -> Self: """Validate whole request object.""" fields = [ @@ -273,7 +273,7 @@ def validate_update_model_request(self) -> Self: raise ValueError("Embedding model cannot be set with streaming enabled.") return self - @field_validator("autoScalingInstanceConfig") # type: ignore + @field_validator("autoScalingInstanceConfig") @classmethod def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig: """Validate that the AutoScaling instance config has at least one positive value.""" diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index 06aa0dce..e9ad299c 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -59,23 +59,23 @@ stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config) -@app.exception_handler(ModelNotFoundError) # type: ignore +@app.exception_handler(ModelNotFoundError) async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> JSONResponse: """Handle exception when model cannot be found and translate to a 404 error.""" return JSONResponse(status_code=404, content={"message": str(exc)}) -@app.exception_handler(RequestValidationError) # type: ignore -async def validation_exception_handler(request: Request, exc: RequestValidationError): +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: """Handle exception when request fails validation and and translate to a 
422 error.""" return JSONResponse( status_code=422, content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"} ) -@app.exception_handler(InvalidStateTransitionError) # type: ignore -@app.exception_handler(ModelAlreadyExistsError) # type: ignore -@app.exception_handler(ValueError) # type: ignore +@app.exception_handler(InvalidStateTransitionError) +@app.exception_handler(ModelAlreadyExistsError) +@app.exception_handler(ValueError) async def user_error_handler( request: Request, exc: Union[InvalidStateTransitionError, ModelAlreadyExistsError, ValueError] ) -> JSONResponse: @@ -83,8 +83,8 @@ async def user_error_handler( return JSONResponse(status_code=400, content={"message": str(exc)}) -@app.post(path="", include_in_schema=False) # type: ignore -@app.post(path="/") # type: ignore +@app.post(path="", include_in_schema=False) +@app.post(path="/") async def create_model(create_request: CreateModelRequest) -> CreateModelResponse: """Endpoint to create a model.""" create_handler = CreateModelHandler( @@ -95,8 +95,8 @@ async def create_model(create_request: CreateModelRequest) -> CreateModelRespons return create_handler(create_request=create_request) -@app.get(path="", include_in_schema=False) # type: ignore -@app.get(path="/") # type: ignore +@app.get(path="", include_in_schema=False) +@app.get(path="/") async def list_models() -> ListModelsResponse: """Endpoint to list models.""" list_handler = ListModelsHandler( @@ -107,7 +107,7 @@ async def list_models() -> ListModelsResponse: return list_handler() -@app.get(path="/{model_id}") # type: ignore +@app.get(path="/{model_id}") async def get_model( model_id: Annotated[str, Path(title="The unique model ID of the model to get")], request: Request ) -> GetModelResponse: @@ -120,7 +120,7 @@ async def get_model( return get_handler(model_id=model_id) -@app.put(path="/{model_id}") # type: ignore +@app.put(path="/{model_id}") async def update_model( model_id: Annotated[str, Path(title="The unique model ID of the model to update")], update_request: UpdateModelRequest, @@ -134,7 +134,7 @@ async def update_model( return update_handler(model_id=model_id, update_request=update_request) -@app.delete(path="/{model_id}") # type: ignore +@app.delete(path="/{model_id}") async def delete_model( model_id: Annotated[str, Path(title="The unique model ID of the model to delete")], request: Request ) -> DeleteModelResponse: @@ -147,7 +147,7 @@ async def delete_model( return delete_handler(model_id=model_id) -@app.get(path="/metadata/instances") # type: ignore +@app.get(path="/metadata/instances") async def get_instances() -> list[str]: """Endpoint to list available instances in this region.""" return list(sess.get_service_model("ec2").shape_for("InstanceType").enum) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py index 6cf1b591..163e8303 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/embeddings.py @@ -27,7 +27,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.EMBEDDINGS.value}") # type: ignore +@router.post(f"/{RestApiResource.EMBEDDINGS.value}") async def embeddings(request: EmbeddingsRequest) -> JSONResponse: """Text embeddings.""" response = await handle_embeddings(request.dict()) diff --git a/lib/serve/rest-api/src/api/endpoints/v1/generation.py b/lib/serve/rest-api/src/api/endpoints/v1/generation.py index b7a21405..6413035c 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/generation.py +++ 
b/lib/serve/rest-api/src/api/endpoints/v1/generation.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.post(f"/{RestApiResource.GENERATE.value}") # type: ignore +@router.post(f"/{RestApiResource.GENERATE.value}") async def generate(request: GenerateRequest) -> JSONResponse: """Text generation.""" response = await handle_generate(request.dict()) @@ -41,7 +41,7 @@ async def generate(request: GenerateRequest) -> JSONResponse: return JSONResponse(content=response, status_code=200) -@router.post(f"/{RestApiResource.GENERATE_STREAM.value}") # type: ignore +@router.post(f"/{RestApiResource.GENERATE_STREAM.value}") async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -50,7 +50,7 @@ async def generate_stream(request: GenerateStreamRequest) -> StreamingResponse: ) -@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS.value}") # type: ignore +@router.post(f"/{RestApiResource.OPENAI_CHAT_COMPLETIONS.value}") async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( @@ -59,7 +59,7 @@ async def openai_chat_completion_generate_stream(request: OpenAIChatCompletionsR ) -@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS.value}") # type: ignore +@router.post(f"/{RestApiResource.OPENAI_COMPLETIONS.value}") async def openai_completion_generate_stream(request: OpenAICompletionsRequest) -> StreamingResponse: """Text generation with streaming.""" return StreamingResponse( diff --git a/lib/serve/rest-api/src/api/endpoints/v1/models.py b/lib/serve/rest-api/src/api/endpoints/v1/models.py index 35eab3a8..e3d37455 100644 --- a/lib/serve/rest-api/src/api/endpoints/v1/models.py +++ b/lib/serve/rest-api/src/api/endpoints/v1/models.py @@ -33,7 +33,7 @@ router = APIRouter() -@router.get(f"/{RestApiResource.DESCRIBE_MODEL.value}") # type: ignore +@router.get(f"/{RestApiResource.DESCRIBE_MODEL.value}") async def describe_model( provider: str = Query( None, @@ -52,7 +52,7 @@ async def describe_model( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.DESCRIBE_MODELS.value}") # type: ignore +@router.get(f"/{RestApiResource.DESCRIBE_MODELS.value}") async def describe_models( model_types: Optional[List[ModelType]] = Query( None, @@ -69,7 +69,7 @@ async def describe_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.LIST_MODELS.value}") # type: ignore +@router.get(f"/{RestApiResource.LIST_MODELS.value}") async def list_models( model_types: Optional[List[ModelType]] = Query( None, @@ -86,7 +86,7 @@ async def list_models( return JSONResponse(content=response, status_code=200) -@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS.value}") # type: ignore +@router.get(f"/{RestApiResource.OPENAI_LIST_MODELS.value}") async def openai_list_models() -> JSONResponse: """List models for OpenAI Compatibility. 
Only returns TEXTGEN models.""" response = await handle_openai_list_models() diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index b98aaacf..8baba5c3 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -82,9 +82,7 @@ def generate_response(iterator: Iterator[Union[str, bytes]]) -> Iterator[str]: yield f"{line}\n\n" -@router.api_route( - "/{api_path:path}", methods=["GET", "POST", "OPTIONS", "PUT", "PATCH", "DELETE", "HEAD"] -) # type: ignore +@router.api_route("/{api_path:path}", methods=["GET", "POST", "OPTIONS", "PUT", "PATCH", "DELETE", "HEAD"]) async def litellm_passthrough(request: Request, api_path: str) -> Response: """ Pass requests directly to LiteLLM. LiteLLM and deployed models will respond here directly. diff --git a/lib/serve/rest-api/src/api/routes.py b/lib/serve/rest-api/src/api/routes.py index f0e3f410..ca79631a 100644 --- a/lib/serve/rest-api/src/api/routes.py +++ b/lib/serve/rest-api/src/api/routes.py @@ -40,7 +40,7 @@ ) -@router.get("/health") # type: ignore +@router.get("/health") async def health_check() -> JSONResponse: """Health check path. diff --git a/lib/serve/rest-api/src/handlers/generation.py b/lib/serve/rest-api/src/handlers/generation.py index 1e05d12e..bf35adb8 100644 --- a/lib/serve/rest-api/src/handlers/generation.py +++ b/lib/serve/rest-api/src/handlers/generation.py @@ -31,7 +31,7 @@ async def handle_generate(request_data: Dict[str, Any]) -> Dict[str, Any]: return response.dict() # type: ignore -@handle_stream_exceptions # type: ignore +@handle_stream_exceptions async def handle_generate_stream(request_data: Dict[str, Any]) -> AsyncGenerator[str, None]: """Handle for generate_stream endpoint.""" model, model_kwargs, text = await validate_and_prepare_llm_request(request_data, RestApiResource.GENERATE_STREAM) @@ -57,7 +57,7 @@ def parse_model_provider_names(model_string: str) -> Tuple[str, str]: return model_name, provider -@handle_stream_exceptions # type: ignore +@handle_stream_exceptions async def handle_openai_generate_stream( request_data: Dict[str, Any], is_text_completion: bool = False ) -> AsyncGenerator[str, None]: diff --git a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py index 167bc1af..a1415be9 100644 --- a/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py +++ b/lib/serve/rest-api/src/lisa_serve/ecs/textgen/tgi.py @@ -211,16 +211,18 @@ async def openai_generate_stream( object="text_completion" if is_text_completion else "chat.completion.chunk", system_fingerprint=fingerprint, choices=[ - OpenAICompletionsChoice( - index=0, - finish_reason=resp.details.finish_reason if resp.details else None, - text=resp.token.text, - ) - if is_text_completion - else OpenAIChatCompletionsChoice( - index=0, - finish_reason=resp.details.finish_reason if resp.details else None, - delta=OpenAIChatCompletionsDelta(content=resp.token.text, role="assistant"), + ( + OpenAICompletionsChoice( + index=0, + finish_reason=resp.details.finish_reason if resp.details else None, + text=resp.token.text, + ) + if is_text_completion + else OpenAIChatCompletionsChoice( + index=0, + finish_reason=resp.details.finish_reason if resp.details else None, + delta=OpenAIChatCompletionsDelta(content=resp.token.text, role="assistant"), + ) ) ], ) diff --git a/lib/serve/rest-api/src/main.py b/lib/serve/rest-api/src/main.py index 
603120aa..e28e901d 100644 --- a/lib/serve/rest-api/src/main.py +++ b/lib/serve/rest-api/src/main.py @@ -147,7 +147,7 @@ async def lifespan(app: FastAPI): # type: ignore ############## -@app.middleware("http") # type: ignore +@app.middleware("http") async def process_request(request: Request, call_next: Any) -> Any: """Middleware for processing all HTTP requests.""" event = "process_request" diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index 9e5bec09..9ced3150 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -25,10 +25,8 @@ secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"]) -@click.command() # type: ignore -@click.option( - "-f", "--filepath", type=click.Path(exists=True, file_okay=True, dir_okay=False, writable=True) -) # type: ignore +@click.command() +@click.option("-f", "--filepath", type=click.Path(exists=True, file_okay=True, dir_okay=False, writable=True)) def generate_config(filepath: str) -> None: """Read LiteLLM configuration and rewrite it with LISA-deployed model information.""" with open(filepath, "r") as fp: diff --git a/lisa-sdk/lisapy/main.py b/lisa-sdk/lisapy/main.py index 8c8ea176..af177b58 100644 --- a/lisa-sdk/lisapy/main.py +++ b/lisa-sdk/lisapy/main.py @@ -51,7 +51,7 @@ class Lisa(BaseModel): _session: Session - @field_validator("url") # type: ignore + @field_validator("url") def validate_url(cls: "Lisa", v: str) -> str: """Validate URL is properly formatted.""" url = v.rstrip("/") diff --git a/lisa-sdk/tests/test_client.py b/lisa-sdk/tests/test_client.py index c64419aa..e3758497 100644 --- a/lisa-sdk/tests/test_client.py +++ b/lisa-sdk/tests/test_client.py @@ -22,13 +22,13 @@ from lisapy.types import ModelKwargs, ModelType -@pytest.fixture(scope="session") # type: ignore +@pytest.fixture(scope="session") def url(pytestconfig: pytest.Config) -> Any: """Get the url argument.""" return pytestconfig.getoption("url") -@pytest.fixture(scope="session") # type: ignore +@pytest.fixture(scope="session") def verify(pytestconfig: pytest.Config) -> Union[bool, Any]: """Get the verify argument.""" if pytestconfig.getoption("verify") == "false": @@ -114,7 +114,7 @@ def test_generate_stream(url: str, verify: Union[bool, str]) -> None: assert response.generated_tokens == 1 -@pytest.mark.asyncio # type: ignore +@pytest.mark.asyncio async def test_generate_async(url: str, verify: Union[bool, str]) -> None: """Generates a batch async response from a textgen.tgi model.""" client = Lisa(url=url, verify=verify) @@ -127,7 +127,7 @@ async def test_generate_async(url: str, verify: Union[bool, str]) -> None: assert response.generated_tokens == 1 -@pytest.mark.asyncio # type: ignore +@pytest.mark.asyncio async def test_generate_stream_async(url: str, verify: Union[bool, str]) -> None: """Generates a streaming async response from a textgen.tgi model.""" client = Lisa(url=url, verify=verify) diff --git a/pyproject.toml b/pyproject.toml index da198d79..4930071a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ skip_glob = [ [tool.mypy] ignore_missing_imports = true disallow_untyped_defs = true -disallow_untyped_decorators = true +disallow_untyped_decorators = false disallow_incomplete_defs = true disallow_any_unimported = false no_implicit_optional = true diff --git a/test/cdk/stacks/iam-stack.test.ts b/test/cdk/stacks/iam-stack.test.ts index 2f766e94..a4aac9d6 100644 --- 
a/test/cdk/stacks/iam-stack.test.ts +++ b/test/cdk/stacks/iam-stack.test.ts @@ -89,7 +89,7 @@ describe.each(regions)('IAM Stack CDK Nag Tests | Region Test: %s', (awsRegion: test('AwsSolutions CDK NAG Errors', () => { const errors = Annotations.fromStack(stack).findError('*', Match.stringLikeRegexp('AwsSolutions-.*')); - expect(errors.length).toBe(12); + expect(errors.length).toBe(14); }); test('NIST800.53r5 CDK NAG Warnings', () => { From 6bf4a0381cab156abb4a27076a1276f950889862 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Tue, 12 Nov 2024 13:06:15 -0700 Subject: [PATCH 14/15] npm install --- package-lock.json | 1 + 1 file changed, 1 insertion(+) diff --git a/package-lock.json b/package-lock.json index 948445c9..0ac5490a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,6 +17,7 @@ "js-yaml": "^4.1.0", "lodash": "^4.17.21", "source-map-support": "^0.5.21", + "util": "^0.12.5", "zod": "^3.22.3" }, "bin": { From 6b608bd3f614c9f2085bbf05b545095c7c28af2b Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Wed, 20 Nov 2024 10:43:42 -0700 Subject: [PATCH 15/15] pre-commit --- lambda/repository/lambda_functions.py | 198 +++++++++++++------------- lib/rag/index.ts | 2 +- lib/schema.ts | 2 +- lib/serve/rest-api/Dockerfile | 2 +- 4 files changed, 102 insertions(+), 102 deletions(-) diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py index 3bf624e1..8bde4ab0 100644 --- a/lambda/repository/lambda_functions.py +++ b/lambda/repository/lambda_functions.py @@ -62,6 +62,105 @@ def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings: ) return embedding + # Create embeddings client that matches LisaOpenAIEmbeddings interface + + +class PipelineEmbeddings: + def __init__(self) -> None: + try: + # Get the management key secret name from SSM Parameter Store + secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) + secret_name = secret_name_param["Parameter"]["Value"] + + # Get the management token from Secrets Manager using the secret name + secret_response = secrets_client.get_secret_value(SecretId=secret_name) + self.token = secret_response["SecretString"] + + # Get the API endpoint from SSM + lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) + self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" + + # Get certificate path for SSL verification + self.cert_path = get_cert_path(iam_client) + + logger.info("Successfully initialized pipeline embeddings") + except Exception: + logger.error("Failed to initialize pipeline embeddings", exc_info=True) + raise + + def embed_documents(self, texts: List[str], model_name: str) -> List[List[float]]: + if not texts: + raise ValidationError("No texts provided for embedding") + + logger.info(f"Embedding {len(texts)} documents") + try: + url = f"{self.base_url}/embeddings" + request_data = {"input": texts, "model": model_name} + + response = requests.post( + url, + json=request_data, + headers={"Authorization": self.token, "Content-Type": "application/json"}, + verify=self.cert_path, # Use proper SSL verification + timeout=300, # 5 minute timeout + ) + + if response.status_code != 200: + logger.error(f"Embedding request failed with status {response.status_code}") + logger.error(f"Response content: {response.text}") + raise Exception(f"Embedding request failed with status {response.status_code}") + + result = response.json() + logger.debug(f"API Response: {result}") # 
Log the full response for debugging + + # Handle different response formats + embeddings = [] + if isinstance(result, dict): + if "data" in result: + # OpenAI-style format + for item in result["data"]: + if isinstance(item, dict) and "embedding" in item: + embeddings.append(item["embedding"]) + else: + embeddings.append(item) # Assume the item itself is the embedding + else: + # Try to find embeddings in the response + for key in ["embeddings", "embedding", "vectors", "vector"]: + if key in result: + embeddings = result[key] + break + elif isinstance(result, list): + # Direct list format + embeddings = result + + if not embeddings: + logger.error(f"Could not find embeddings in response: {result}") + raise Exception("No embeddings found in API response") + + if len(embeddings) != len(texts): + logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})") + raise Exception("Number of embeddings does not match number of input texts") + + logger.info(f"Successfully embedded {len(texts)} documents") + return embeddings + + except requests.Timeout: + logger.error("Embedding request timed out") + raise Exception("Embedding request timed out after 5 minutes") + except requests.RequestException as e: + logger.error(f"Request failed: {str(e)}", exc_info=True) + raise + except Exception as e: + logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True) + raise + + def embed_query(self, text: str, model_name: str) -> List[float]: + if not text or not isinstance(text, str): + raise ValidationError("Invalid query text") + + logger.info("Embedding single query text") + return self.embed_documents([text], model_name)[0] + def _get_embeddings_pipeline(model_name: str) -> Any: """ @@ -77,105 +176,6 @@ def _get_embeddings_pipeline(model_name: str) -> Any: logger.info("Starting pipeline embeddings request") validate_model_name(model_name) - # Create embeddings client that matches LisaOpenAIEmbeddings interface - class PipelineEmbeddings: - def __init__(self) -> None: - try: - # Get the management key secret name from SSM Parameter Store - secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"]) - secret_name = secret_name_param["Parameter"]["Value"] - - # Get the management token from Secrets Manager using the secret name - secret_response = secrets_client.get_secret_value(SecretId=secret_name) - self.token = secret_response["SecretString"] - - # Get the API endpoint from SSM - lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"]) - self.base_url = ( - f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve" - ) - - # Get certificate path for SSL verification - self.cert_path = get_cert_path(iam_client) - - logger.info("Successfully initialized pipeline embeddings") - except Exception: - logger.error("Failed to initialize pipeline embeddings", exc_info=True) - raise - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - if not texts: - raise ValidationError("No texts provided for embedding") - - logger.info(f"Embedding {len(texts)} documents") - try: - url = f"{self.base_url}/embeddings" - request_data = {"input": texts, "model": model_name} - - response = requests.post( - url, - json=request_data, - headers={"Authorization": self.token, "Content-Type": "application/json"}, - verify=self.cert_path, # Use proper SSL verification - timeout=300, # 5 minute timeout - ) - - if response.status_code != 200: - logger.error(f"Embedding request 
failed with status {response.status_code}") - logger.error(f"Response content: {response.text}") - raise Exception(f"Embedding request failed with status {response.status_code}") - - result = response.json() - logger.debug(f"API Response: {result}") # Log the full response for debugging - - # Handle different response formats - embeddings = [] - if isinstance(result, dict): - if "data" in result: - # OpenAI-style format - for item in result["data"]: - if isinstance(item, dict) and "embedding" in item: - embeddings.append(item["embedding"]) - else: - embeddings.append(item) # Assume the item itself is the embedding - else: - # Try to find embeddings in the response - for key in ["embeddings", "embedding", "vectors", "vector"]: - if key in result: - embeddings = result[key] - break - elif isinstance(result, list): - # Direct list format - embeddings = result - - if not embeddings: - logger.error(f"Could not find embeddings in response: {result}") - raise Exception("No embeddings found in API response") - - if len(embeddings) != len(texts): - logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})") - raise Exception("Number of embeddings does not match number of input texts") - - logger.info(f"Successfully embedded {len(texts)} documents") - return embeddings - - except requests.Timeout: - logger.error("Embedding request timed out") - raise Exception("Embedding request timed out after 5 minutes") - except requests.RequestException as e: - logger.error(f"Request failed: {str(e)}", exc_info=True) - raise - except Exception as e: - logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True) - raise - - def embed_query(self, text: str) -> List[float]: - if not text or not isinstance(text, str): - raise ValidationError("Invalid query text") - - logger.info("Embedding single query text") - return self.embed_documents([text])[0] - return PipelineEmbeddings() diff --git a/lib/rag/index.ts b/lib/rag/index.ts index cbfd831b..cfe1bbc0 100644 --- a/lib/rag/index.ts +++ b/lib/rag/index.ts @@ -203,7 +203,7 @@ export class LisaRagStack extends Stack { version: EngineVersion.OPENSEARCH_2_9, enableVersionUpgrade: true, vpc: vpc.vpc, - ...vpc.subnetSelection ? {vpcSubnets: [vpc.subnetSelection]} : {}, + ...(vpc.subnetSelection && {vpcSubnets: [vpc.subnetSelection]}), ebs: { enabled: true, volumeSize: ragConfig.opensearchConfig.volumeSize, diff --git a/lib/schema.ts b/lib/schema.ts index ad76342e..967e0620 100644 --- a/lib/schema.ts +++ b/lib/schema.ts @@ -467,7 +467,7 @@ const RagRepositoryConfigSchema = z s3Prefix: z.string(), trigger: z.union([z.literal('daily'), z.literal('event')]), collectionName: z.string() - })).optional(), + })).optional().describe('Rag ingestion pipeline for automated inclusion into a vector store from S3'), }) .refine((input) => { return !((input.type === RagRepositoryType.OPENSEARCH && input.opensearchConfig === undefined) || diff --git a/lib/serve/rest-api/Dockerfile b/lib/serve/rest-api/Dockerfile index f64fd690..8a2953bf 100644 --- a/lib/serve/rest-api/Dockerfile +++ b/lib/serve/rest-api/Dockerfile @@ -1,6 +1,6 @@ # Use an argument for the base image ARG BASE_IMAGE -FROM --platform=linux/amd64 ${BASE_IMAGE} +FROM ${BASE_IMAGE} # Copy LiteLLM config directly out of the LISA config.yaml file ARG LITELLM_CONFIG
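Two things about the final PipelineEmbeddings refactor above deserve a note. First, the class moves from a closure inside _get_embeddings_pipeline to module level, which changes its interface: embed_documents and embed_query now take model_name as a parameter rather than capturing it from the enclosing scope. Second, embed_documents normalizes several possible /embeddings response shapes — an OpenAI-style data envelope, a dict keyed by one of embeddings/embedding/vectors/vector, or a bare list of vectors. A self-contained sketch of that normalization, runnable outside the Lambda (the example payloads are illustrative):

from typing import Any, List


def extract_embeddings(result: Any, expected_count: int) -> List[List[float]]:
    embeddings: List[List[float]] = []
    if isinstance(result, dict):
        if "data" in result:
            # OpenAI-style: {"data": [{"embedding": [...]}, ...]}
            for item in result["data"]:
                if isinstance(item, dict) and "embedding" in item:
                    embeddings.append(item["embedding"])
                else:
                    embeddings.append(item)  # the item itself is the vector
        else:
            # Fall back to the first recognized key
            for key in ("embeddings", "embedding", "vectors", "vector"):
                if key in result:
                    embeddings = result[key]
                    break
    elif isinstance(result, list):
        embeddings = result  # bare list of vectors
    if not embeddings:
        raise ValueError("No embeddings found in API response")
    if len(embeddings) != expected_count:
        raise ValueError("Number of embeddings does not match number of inputs")
    return embeddings


# Both envelope styles normalize to the same vectors.
assert extract_embeddings({"data": [{"embedding": [0.1, 0.2]}]}, 1) == [[0.1, 0.2]]
assert extract_embeddings([[0.1, 0.2]], 1) == [[0.1, 0.2]]

The same pre-commit patch also reverts the --platform=linux/amd64 pin that the first commit added to the REST API Dockerfile, restoring the plain ${BASE_IMAGE} base.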