RAG pipeline #180

Merged
merged 21 commits on Nov 21, 2024
Changes from 18 commits
8 changes: 3 additions & 5 deletions .pre-commit-config.yaml
@@ -55,7 +55,7 @@ repos:
name: isort (python)

- repo: https://github.com/ambv/black
- rev: '23.10.1'
+ rev: '24.10.0'
hooks:
- id: black

@@ -66,21 +66,19 @@ repos:
args: [--exit-non-zero-on-fix]

- repo: https://github.com/pycqa/flake8
- rev: '6.1.0'
+ rev: '7.1.1'
hooks:
- id: flake8
additional_dependencies:
- flake8-docstrings
- - flake8-broken-line
- flake8-bugbear
- flake8-comprehensions
- flake8-debugger
- - flake8-string-format
args:
- --docstring-convention=numpy
- --max-line-length=120
- --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends
- - --ignore=B008 # Ignore error for function calls in argument defaults
+ - --ignore=B008,E203 # Ignore error for function calls in argument defaults
exclude: ^(__init__.py$|.*\/__init__.py$)


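The new E203 entry is likely for black compatibility: black (bumped to 24.10.0 above) formats slices with compound bounds using a space before the colon, which flake8 flags as E203. A quick illustration, not part of this PR:

# E203 ("whitespace before ':'") conflicts with black's slice style:
# black writes compound slice bounds with symmetric spaces around the colon.
items = list(range(10))
half = len(items) // 2
head = items[:half]       # simple slice: no space, no E203
tail = items[half + 1 :]  # black's style for compound bounds; E203 fires without the ignore
print(head, tail)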
13 changes: 13 additions & 0 deletions lambda/__init__.py
@@ -0,0 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
12 changes: 6 additions & 6 deletions lambda/models/domain_objects.py
@@ -98,7 +98,7 @@ class AutoScalingConfig(BaseModel):
defaultInstanceWarmup: PositiveInt
metricConfig: MetricConfig

-@model_validator(mode="after") # type: ignore
+@model_validator(mode="after")
def validate_auto_scaling_config(self) -> Self:
"""Validate autoScalingConfig values."""
if self.minCapacity > self.maxCapacity:
@@ -115,7 +115,7 @@ class AutoScalingInstanceConfig(BaseModel):
maxCapacity: Optional[PositiveInt] = None
desiredCapacity: Optional[PositiveInt] = None

-@model_validator(mode="after") # type: ignore
+@model_validator(mode="after")
def validate_auto_scaling_instance_config(self) -> Self:
"""Validate autoScalingInstanceConfig values."""
config_fields = [self.minCapacity, self.maxCapacity, self.desiredCapacity]
@@ -155,7 +155,7 @@ class ContainerConfig(BaseModel):
healthCheckConfig: ContainerHealthCheckConfig
environment: Optional[Dict[str, str]] = {}

-@field_validator("environment") # type: ignore
+@field_validator("environment")
@classmethod
def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]:
"""Validate that all keys in Dict are not empty."""
@@ -201,7 +201,7 @@ class CreateModelRequest(BaseModel):
modelUrl: Optional[str] = None
streaming: Optional[bool] = False

-@model_validator(mode="after") # type: ignore
+@model_validator(mode="after")
def validate_create_model_request(self) -> Self:
"""Validate whole request object."""
# Validate that an embedding model cannot be set as streaming-enabled
Expand Down Expand Up @@ -252,7 +252,7 @@ class UpdateModelRequest(BaseModel):
modelType: Optional[ModelType] = None
streaming: Optional[bool] = None

-@model_validator(mode="after") # type: ignore
+@model_validator(mode="after")
def validate_update_model_request(self) -> Self:
"""Validate whole request object."""
fields = [
@@ -273,7 +273,7 @@ def validate_update_model_request(self) -> Self:
raise ValueError("Embedding model cannot be set with streaming enabled.")
return self

-@field_validator("autoScalingInstanceConfig") # type: ignore
+@field_validator("autoScalingInstanceConfig")
@classmethod
def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig:
"""Validate that the AutoScaling instance config has at least one positive value."""
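Side note on the dropped # type: ignore comments: pydantic v2's validator decorators are fully typed, so the suppressions are presumably no longer needed. A minimal sketch of the after-mode validator pattern these models use (names illustrative, not the PR's actual models):

from pydantic import BaseModel, PositiveInt, model_validator
from typing_extensions import Self


class ScalingConfig(BaseModel):
    """Illustrative stand-in for the AutoScalingConfig pattern above."""

    minCapacity: PositiveInt
    maxCapacity: PositiveInt

    @model_validator(mode="after")
    def check_bounds(self) -> Self:
        # Runs after field validation, so both values are present and positive.
        if self.minCapacity > self.maxCapacity:
            raise ValueError("minCapacity cannot exceed maxCapacity.")
        return self


ScalingConfig(minCapacity=1, maxCapacity=4)    # validates cleanly
# ScalingConfig(minCapacity=5, maxCapacity=2)  # would raise pydantic.ValidationError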
28 changes: 14 additions & 14 deletions lambda/models/lambda_functions.py
@@ -59,32 +59,32 @@
stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config)


-@app.exception_handler(ModelNotFoundError) # type: ignore
+@app.exception_handler(ModelNotFoundError)
async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> JSONResponse:
"""Handle exception when model cannot be found and translate to a 404 error."""
return JSONResponse(status_code=404, content={"message": str(exc)})


-@app.exception_handler(RequestValidationError) # type: ignore
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
"""Handle exception when request fails validation and and translate to a 422 error."""
return JSONResponse(
status_code=422, content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"}
)


-@app.exception_handler(InvalidStateTransitionError) # type: ignore
-@app.exception_handler(ModelAlreadyExistsError) # type: ignore
-@app.exception_handler(ValueError) # type: ignore
+@app.exception_handler(InvalidStateTransitionError)
+@app.exception_handler(ModelAlreadyExistsError)
+@app.exception_handler(ValueError)
async def user_error_handler(
request: Request, exc: Union[InvalidStateTransitionError, ModelAlreadyExistsError, ValueError]
) -> JSONResponse:
"""Handle errors when customer requests options that cannot be processed."""
return JSONResponse(status_code=400, content={"message": str(exc)})
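
These handlers convert domain errors into HTTP status codes instead of generic 500s; a self-contained sketch of the same pattern (hypothetical route and app, not this file's actual endpoints):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

demo = FastAPI()


@demo.exception_handler(ValueError)
async def demo_user_error_handler(request: Request, exc: ValueError) -> JSONResponse:
    # Same idea as above: surface user-caused errors as 400s, not 500s.
    return JSONResponse(status_code=400, content={"message": str(exc)})


@demo.get("/models/{model_id}")
async def demo_get_model(model_id: str) -> dict:
    if model_id == "bad":
        raise ValueError(f"Model '{model_id}' cannot be processed.")
    return {"modelId": model_id}


client = TestClient(demo)
assert client.get("/models/ok").status_code == 200
assert client.get("/models/bad").status_code == 400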


-@app.post(path="", include_in_schema=False) # type: ignore
-@app.post(path="/") # type: ignore
+@app.post(path="", include_in_schema=False)
+@app.post(path="/")
async def create_model(create_request: CreateModelRequest) -> CreateModelResponse:
"""Endpoint to create a model."""
create_handler = CreateModelHandler(
@@ -95,8 +95,8 @@ async def create_model(create_request: CreateModelRequest) -> CreateModelRespons
return create_handler(create_request=create_request)


-@app.get(path="", include_in_schema=False) # type: ignore
-@app.get(path="/") # type: ignore
+@app.get(path="", include_in_schema=False)
+@app.get(path="/")
async def list_models() -> ListModelsResponse:
"""Endpoint to list models."""
list_handler = ListModelsHandler(
@@ -107,7 +107,7 @@ async def list_models() -> ListModelsResponse:
return list_handler()


-@app.get(path="/{model_id}") # type: ignore
+@app.get(path="/{model_id}")
async def get_model(
model_id: Annotated[str, Path(title="The unique model ID of the model to get")], request: Request
) -> GetModelResponse:
@@ -120,7 +120,7 @@ async def get_model(
return get_handler(model_id=model_id)


-@app.put(path="/{model_id}") # type: ignore
+@app.put(path="/{model_id}")
async def update_model(
model_id: Annotated[str, Path(title="The unique model ID of the model to update")],
update_request: UpdateModelRequest,
@@ -134,7 +134,7 @@ async def update_model(
return update_handler(model_id=model_id, update_request=update_request)


-@app.delete(path="/{model_id}") # type: ignore
+@app.delete(path="/{model_id}")
async def delete_model(
model_id: Annotated[str, Path(title="The unique model ID of the model to delete")], request: Request
) -> DeleteModelResponse:
@@ -147,7 +147,7 @@ async def delete_model(
return delete_handler(model_id=model_id)


-@app.get(path="/metadata/instances") # type: ignore
+@app.get(path="/metadata/instances")
async def get_instances() -> list[str]:
"""Endpoint to list available instances in this region."""
return list(sess.get_service_model("ec2").shape_for("InstanceType").enum)
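
Worth noting: get_instances doesn't call the EC2 API; it reads the InstanceType enum from botocore's bundled service model, so it works offline and needs no credentials. A standalone sketch of the same lookup:

import botocore.session

# The instance-type list ships with botocore's local EC2 service model,
# so no AWS call (and no credentials) are needed.
sess = botocore.session.get_session()
instance_types = list(sess.get_service_model("ec2").shape_for("InstanceType").enum)
print(f"{len(instance_types)} instance types, e.g. {instance_types[:3]}")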
126 changes: 122 additions & 4 deletions lambda/repository/lambda_functions.py
@@ -19,17 +19,18 @@
from typing import Any, Dict, List

import boto3
import create_env_variables # noqa: F401
import requests
from botocore.config import Config
from lisapy.langchain import LisaOpenAIEmbeddings
-from lisapy.utils import get_cert_path
-from utilities.common_functions import api_wrapper, get_id_token, retry_config
+from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config
from utilities.file_processing import process_record
from utilities.validation import validate_model_name, ValidationError
from utilities.vector_store import get_vector_store_client

logger = logging.getLogger(__name__)
session = boto3.Session()
ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)
secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)
iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config)
s3 = session.client(
"s3",
@@ -54,13 +55,130 @@ def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"]

base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve"
cert_path = get_cert_path(iam_client)

embedding = LisaOpenAIEmbeddings(
-lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=get_cert_path(iam_client)
+lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path
)
return embedding


def _get_embeddings_pipeline(model_name: str) -> Any:
"""
Get embeddings for pipeline requests using management token.

Args:
model_name: Name of the embedding model to use

Raises:
ValidationError: If model name is invalid
Exception: If API request fails
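
Returns:
An embeddings client (PipelineEmbeddings) exposing embed_documents and embed_query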
"""
logger.info("Starting pipeline embeddings request")
validate_model_name(model_name)

# Create embeddings client that matches LisaOpenAIEmbeddings interface
class PipelineEmbeddings:
def __init__(self) -> None:
try:
# Get the management key secret name from SSM Parameter Store
secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"])
secret_name = secret_name_param["Parameter"]["Value"]

# Get the management token from Secrets Manager using the secret name
secret_response = secrets_client.get_secret_value(SecretId=secret_name)
self.token = secret_response["SecretString"]

# Get the API endpoint from SSM
lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"])
Contributor review comment: Not a bad idea to add some error handling around missing environment variables
self.base_url = (
f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve"
)

# Get certificate path for SSL verification
self.cert_path = get_cert_path(iam_client)

logger.info("Successfully initialized pipeline embeddings")
except Exception:
logger.error("Failed to initialize pipeline embeddings", exc_info=True)
raise

def embed_documents(self, texts: List[str]) -> List[List[float]]:
if not texts:
raise ValidationError("No texts provided for embedding")

logger.info(f"Embedding {len(texts)} documents")
try:
url = f"{self.base_url}/embeddings"
request_data = {"input": texts, "model": model_name}

response = requests.post(
url,
json=request_data,
headers={"Authorization": self.token, "Content-Type": "application/json"},
verify=self.cert_path, # Use proper SSL verification
timeout=300, # 5 minute timeout
)

if response.status_code != 200:
logger.error(f"Embedding request failed with status {response.status_code}")
logger.error(f"Response content: {response.text}")
raise Exception(f"Embedding request failed with status {response.status_code}")

result = response.json()
logger.debug(f"API Response: {result}") # Log the full response for debugging

# Handle different response formats
embeddings = []
if isinstance(result, dict):
if "data" in result:
# OpenAI-style format
for item in result["data"]:
if isinstance(item, dict) and "embedding" in item:
embeddings.append(item["embedding"])
else:
embeddings.append(item) # Assume the item itself is the embedding
else:
# Try to find embeddings in the response
for key in ["embeddings", "embedding", "vectors", "vector"]:
if key in result:
embeddings = result[key]
break
elif isinstance(result, list):
# Direct list format
embeddings = result

if not embeddings:
logger.error(f"Could not find embeddings in response: {result}")
raise Exception("No embeddings found in API response")

if len(embeddings) != len(texts):
logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})")
raise Exception("Number of embeddings does not match number of input texts")

logger.info(f"Successfully embedded {len(texts)} documents")
return embeddings

except requests.Timeout:
logger.error("Embedding request timed out")
raise Exception("Embedding request timed out after 5 minutes")
except requests.RequestException as e:
logger.error(f"Request failed: {str(e)}", exc_info=True)
raise
except Exception as e:
logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True)
raise

def embed_query(self, text: str) -> List[float]:
if not text or not isinstance(text, str):
raise ValidationError("Invalid query text")

logger.info("Embedding single query text")
return self.embed_documents([text])[0]

return PipelineEmbeddings()

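For reference, the OpenAI-style payload that embed_documents' "data" branch expects looks roughly like this (values illustrative, assuming the serve endpoint mirrors the OpenAI embeddings schema):

from typing import List

# Illustrative response only; real vectors have model-dependent dimensions.
example_response = {
    "object": "list",
    "model": "example-embedding-model",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.01, -0.02, 0.03]},
        {"object": "embedding", "index": 1, "embedding": [0.04, 0.00, -0.05]},
    ],
}

embeddings: List[List[float]] = [item["embedding"] for item in example_response["data"]]
assert len(embeddings) == 2  # one vector per input text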

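On the inline review note about missing environment variables: one way to fail fast, sketched as an illustration only (require_env is a hypothetical helper, not part of this PR):

import os


def require_env(name: str) -> str:
    """Return a required environment variable or raise a clear error."""
    value = os.environ.get(name)
    if not value:
        # Hypothetical guard: surfaces misconfiguration up front
        # instead of a KeyError deep inside a request handler.
        raise RuntimeError(f"Required environment variable '{name}' is not set.")
    return value


# e.g. secret_name_param = ssm_client.get_parameter(Name=require_env("MANAGEMENT_KEY_SECRET_NAME_PS"))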
@api_wrapper
def list_all(event: dict, context: dict) -> List[Dict[str, Any]]:
"""Return info on all available repositories.
Expand Down