feat: add GuardrailsEngine for llama index #1005

Closed
wants to merge 10 commits
Empty file.
5 changes: 5 additions & 0 deletions guardrails/integrations/llama_index/__init__.py
@@ -0,0 +1,5 @@
from guardrails.integrations.llama_index.guardrails_engine import (
GuardrailsEngine,
)

__all__ = ["GuardrailsEngine"]
238 changes: 238 additions & 0 deletions guardrails/integrations/llama_index/guardrails_engine.py
@@ -0,0 +1,238 @@
# pyright: reportMissingImports=false

from typing import Any, Optional, Dict, List, Union, TYPE_CHECKING, cast
from guardrails import Guard
from guardrails.errors import ValidationError
from guardrails.classes.validation_outcome import ValidationOutcome
from guardrails.decorators.experimental import experimental
import importlib.util

LLAMA_INDEX_AVAILABLE = importlib.util.find_spec("llama_index") is not None

if TYPE_CHECKING or LLAMA_INDEX_AVAILABLE:
from llama_index.core.query_engine import BaseQueryEngine
from llama_index.core.chat_engine.types import (
BaseChatEngine,
AGENT_CHAT_RESPONSE_TYPE,
AgentChatResponse,
StreamingAgentChatResponse,
)
from llama_index.core.schema import QueryBundle
from llama_index.core.callbacks import CallbackManager
from llama_index.core.base.response.schema import (
RESPONSE_TYPE,
Response,
StreamingResponse,
AsyncStreamingResponse,
PydanticResponse,
)
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.prompts.mixin import PromptMixinType


class GuardrailsEngine(BaseQueryEngine, BaseChatEngine):
def __init__(
self,
engine: Union["BaseQueryEngine", "BaseChatEngine"],
guard: Guard,
guard_kwargs: Optional[Dict[str, Any]] = None,
callback_manager: Optional["CallbackManager"] = None,
):
try:
import llama_index # noqa: F401
except ImportError:
raise ImportError(
"llama_index is not installed. Please install it with "
"`pip install llama-index` to use GuardrailsEngine."
)

self._engine = engine
self._guard = guard
self._guard_kwargs = guard_kwargs or {}
self._engine_response = None
super().__init__(callback_manager)

@property
def guard(self) -> Guard:
return self._guard

def engine_api(self, prompt: str, **kwargs) -> str:
if isinstance(self._engine, BaseQueryEngine):
response = self._engine.query(prompt)
elif isinstance(self._engine, BaseChatEngine):
chat_history = kwargs.get("chat_history", [])
response = self._engine.chat(prompt, chat_history)
else:
raise ValueError("Unsupported engine type")

self._engine_response = response
return str(response)

@experimental
def _query(self, query_bundle: "QueryBundle") -> "RESPONSE_TYPE":
if not isinstance(self._engine, BaseQueryEngine):
raise ValueError(
"Cannot perform query with a ChatEngine. Use chat() method instead."
)
if isinstance(query_bundle, str):
query_bundle = QueryBundle(query_bundle)
try:
validated_output = self.guard(
llm_api=self.engine_api,
prompt=query_bundle.query_str,
**self._guard_kwargs,
)

validated_output = cast(ValidationOutcome, validated_output)
self._engine_response = cast(RESPONSE_TYPE, self._engine_response)
if not validated_output.validation_passed:
raise ValidationError(f"Validation failed: {validated_output.error}")
self._update_response_metadata(validated_output)
if validated_output.validation_passed:
if isinstance(self._engine_response, Response):
self._engine_response.response = validated_output.validated_output
elif isinstance(
self._engine_response, (StreamingResponse, AsyncStreamingResponse)
):
self._engine_response.response_txt = (
validated_output.validated_output
)
elif isinstance(self._engine_response, PydanticResponse):
if self._engine_response.response:
import json

json_str = (
validated_output.validated_output
if isinstance(validated_output.validated_output, str)
else json.dumps(validated_output.validated_output)
)
self._engine_response.response = self._engine_response.response.__class__.model_validate_json( # noqa: E501
json_str
)
else:
raise ValueError("Unsupported response type")
except ValidationError as e:
raise ValidationError(f"Validation failed: {str(e)}")
except Exception as e:
raise RuntimeError(f"An error occurred during query processing: {str(e)}")
return self._engine_response

@experimental
def chat(
self, message: str, chat_history: Optional[List["ChatMessage"]] = None
) -> "AGENT_CHAT_RESPONSE_TYPE":
if not isinstance(self._engine, BaseChatEngine):
raise ValueError(
"Cannot perform chat with a QueryEngine. Use query() method instead."
)
if chat_history is None:
chat_history = []
try:
validated_output = self.guard(
llm_api=self.engine_api,
prompt=message,
chat_history=chat_history,
**self._guard_kwargs,
)
response = self._create_chat_response(validated_output)
if response is None:
raise ValueError("Failed to create a valid chat response")

return cast(AGENT_CHAT_RESPONSE_TYPE, response)
except ValidationError as e:
raise ValidationError(f"Validation failed: {str(e)}")
except Exception as e:
raise RuntimeError(f"An error occurred during chat processing: {str(e)}")

def _update_response_metadata(self, validated_output):
if self._engine_response is None:
return
self._engine_response = cast(RESPONSE_TYPE, self._engine_response)

metadata_update = {
"validation_passed": validated_output.validation_passed,
"validated_output": validated_output.validated_output,
"error": validated_output.error,
"raw_llm_output": validated_output.raw_llm_output,
}

if self._engine_response.metadata is None:
self._engine_response.metadata = {}
self._engine_response.metadata.update(metadata_update)

def _create_chat_response(
self, validated_output
) -> Optional[AGENT_CHAT_RESPONSE_TYPE]:
if validated_output.validation_passed:
content = validated_output.validated_output
else:
content = "I'm sorry, but I couldn't generate a valid response."

if self._engine_response is None:
return None

metadata_update = {
"validation_passed": validated_output.validation_passed,
"validated_output": validated_output.validated_output,
"error": validated_output.error,
"raw_llm_output": validated_output.raw_llm_output,
}

if isinstance(self._engine_response, AgentChatResponse):
if self._engine_response.metadata is None:
self._engine_response.metadata = {}
self._engine_response.metadata.update(metadata_update)
elif isinstance(self._engine_response, StreamingAgentChatResponse):
for key, value in metadata_update.items():
setattr(self._engine_response, key, value)

self._engine_response.response = content
return self._engine_response

async def _aquery(self, query_bundle: "QueryBundle") -> "RESPONSE_TYPE":
"""Async version of _query."""
return self._query(query_bundle)

async def achat(
self, message: str, chat_history: Optional[List["ChatMessage"]] = None
):
"""Async version of chat."""
raise NotImplementedError(
"Async chat is not supported in the GuardrailsQueryEngine."
)

def stream_chat(
self, message: str, chat_history: Optional[List["ChatMessage"]] = None
):
"""Stream chat responses."""
raise NotImplementedError(
"Stream chat is not supported in the GuardrailsQueryEngine."
)

async def astream_chat(
self, message: str, chat_history: Optional[List["ChatMessage"]] = None
):
"""Async stream chat responses."""
raise NotImplementedError(
"Async stream chat is not supported in the GuardrailsQueryEngine."
)

def reset(self):
"""Reset the chat history."""
if isinstance(self._engine, BaseChatEngine):
self._engine.reset()
else:
raise NotImplementedError("Reset is only available for chat engines.")

@property
def chat_history(self) -> List["ChatMessage"]:
"""Get the chat history."""
if isinstance(self._engine, BaseChatEngine):
return self._engine.chat_history
raise NotImplementedError("Chat history is only available for chat engines.")

def _get_prompt_modules(self) -> "PromptMixinType":
"""Get prompt modules."""
if isinstance(self._engine, BaseQueryEngine):
return self._engine._get_prompt_modules()
return {}
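
For context, here is a minimal usage sketch of the new engine (not part of this diff). It assumes documents in a local `data/` directory, the `RegexMatch` validator installed from the Guardrails Hub, and credentials for whichever LLM and embedding model the index uses; all names outside the diff are illustrative.

```python
from guardrails import Guard
from guardrails.hub import RegexMatch  # assumes RegexMatch is installed from the Guardrails Hub
from guardrails.integrations.llama_index import GuardrailsEngine
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

# Build an ordinary llama_index query engine.
documents = SimpleDirectoryReader("data").load_data()  # "data" is a placeholder directory
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine()

# Wrap it so every response is validated by the Guard before it is returned.
guard = Guard().use(RegexMatch("[A-Z].*", match_type="search"))
guardrails_engine = GuardrailsEngine(engine=query_engine, guard=guard)

response = guardrails_engine.query("What does the document say?")
print(response.response)  # validated output
print(response.metadata)  # includes validation_passed, validated_output, raw_llm_output
```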
10 changes: 10 additions & 0 deletions guardrails/llm_providers.py
@@ -784,6 +784,16 @@ def get_llm_ask(
except ImportError:
pass

if llm_api is not None:
llm_self = getattr(llm_api, "__self__", None)
if (
llm_self is not None
and hasattr(llm_self, "__class__")
and getattr(llm_self.__class__, "__name__", None) == "GuardrailsEngine"
and getattr(llm_api, "__name__", None) == "engine_api"
):
return ArbitraryCallable(*args, llm_api=llm_api, **kwargs)

if llm_api == get_static_openai_create_func():
return OpenAICallable(*args, **kwargs)
if llm_api == get_static_openai_chat_create_func():
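
The dispatch added above keys off the bound method's owner rather than importing llama_index, so `get_llm_ask` routes the engine's callable to `ArbitraryCallable` without a hard dependency. A small standalone illustration of that check (the stand-in class below is hypothetical; only the names mirror the diff):

```python
class GuardrailsEngine:  # stand-in for the real class; only the method and class names matter here
    def engine_api(self, prompt: str, **kwargs) -> str:
        return "response"

llm_api = GuardrailsEngine().engine_api

llm_self = getattr(llm_api, "__self__", None)
is_guardrails_engine_api = (
    llm_self is not None
    and getattr(llm_self.__class__, "__name__", None) == "GuardrailsEngine"
    and getattr(llm_api, "__name__", None) == "engine_api"
)
assert is_guardrails_engine_api  # such callables are wrapped in ArbitraryCallable
```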
@@ -0,0 +1,111 @@
import pytest
from guardrails import Guard
from guardrails.errors import ValidationError
from typing import List, Optional
from tests.integration_tests.test_assets.validators import RegexMatch

try:
from llama_index.core.query_engine import BaseQueryEngine
from llama_index.core.chat_engine.types import BaseChatEngine, AgentChatResponse
from llama_index.core.schema import QueryBundle
from llama_index.core.base.response.schema import Response
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.prompts.mixin import PromptMixinType
from llama_index.core.callbacks import CallbackManager
from guardrails.integrations.llama_index.guardrails_engine import GuardrailsEngine

class MockQueryEngine(BaseQueryEngine):
def __init__(self, callback_manager: Optional[CallbackManager] = None):
super().__init__(callback_manager)

def _query(self, query_bundle: QueryBundle) -> Response:
return Response(response="Mock response")

async def _aquery(self, query_bundle: QueryBundle) -> Response:
return Response(response="Mock async query response")

def _get_prompt_modules(self) -> PromptMixinType:
return {}

class MockChatEngine(BaseChatEngine):
def chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
return AgentChatResponse(response="Mock response")

async def achat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
) -> AgentChatResponse:
return AgentChatResponse(response="Mock async chat response")

def stream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
):
yield AgentChatResponse(response="Mock stream chat response")

async def astream_chat(
self, message: str, chat_history: Optional[List[ChatMessage]] = None
):
yield AgentChatResponse(response="Mock async stream chat response")

@property
def chat_history(self) -> List[ChatMessage]:
return []

def reset(self):
pass
except ImportError:
pytest.skip("llama_index is not installed", allow_module_level=True)

pytest.importorskip("llama_index")


@pytest.fixture
def guard():
return Guard().use(RegexMatch("Mock response", match_type="search"))


def test_guardrails_engine_init(guard):
engine = MockQueryEngine()
guardrails_engine = GuardrailsEngine(engine, guard)
assert isinstance(guardrails_engine, GuardrailsEngine)
assert guardrails_engine.guard == guard


def test_guardrails_engine_query(guard):
engine = MockQueryEngine()
guardrails_engine = GuardrailsEngine(engine, guard)

result = guardrails_engine._query(QueryBundle(query_str="Mock response"))
assert isinstance(result, Response)
assert result.response == "Mock response"


def test_guardrails_engine_query_validation_failure(guard):
engine = MockQueryEngine()
guardrails_engine = GuardrailsEngine(engine, guard)

engine._query = lambda _: Response(response="Invalid response")

with pytest.raises(ValidationError, match="Validation failed"):
guardrails_engine._query(QueryBundle(query_str="Invalid query"))


def test_guardrails_engine_chat(guard):
engine = MockChatEngine()
guardrails_engine = GuardrailsEngine(engine, guard)

result = guardrails_engine.chat("Mock response")
assert isinstance(result, AgentChatResponse)
assert result.response == "Mock response"


def test_guardrails_engine_unsupported_engine(guard):
class UnsupportedEngine:
pass

engine = UnsupportedEngine()
guardrails_engine = GuardrailsEngine(engine, guard)

with pytest.raises(ValueError, match="Unsupported engine type"):
guardrails_engine.engine_api("Test prompt")