
Commit

Merge pull request #1108 from guardrails-ai/langchain-core-03
zsimjee authored Oct 2, 2024
2 parents 7153745 + 32a04ff commit c7967db
Showing 6 changed files with 2,576 additions and 2,453 deletions.
4 changes: 2 additions & 2 deletions docs/integrations/langchain.md
@@ -13,10 +13,10 @@ This is a comprehensive guide on integrating Guardrails with [LangChain](https:/

## Prerequisites

-1. Ensure you have the following langchain packages installed:
+1. Ensure you have the following langchain packages installed. Also install Guardrails

```bash
-pip install langchain langchain_openai
+pip install "guardrails-ai>=0.5.13" langchain langchain_openai
```

2. As a prerequisite we install the necessary validators from the Guardrails Hub:
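For context, the prerequisites updated above feed into a guide that puts a `Guard` at the end of a LangChain LCEL chain. Below is a minimal sketch of that pattern; the `RegexMatch` validator, the Hub install step, and the `guard.to_runnable()` call are recalled from the Guardrails docs and should be treated as assumptions rather than an exact excerpt of this guide.

```python
# Sketch only: assumes `guardrails hub install hub://guardrails/regex_match` has been run
# and OPENAI_API_KEY is set in the environment.
from guardrails import Guard
from guardrails.hub import RegexMatch  # assumed Hub validator
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Guard that checks the LLM output looks like a US-style phone number.
guard = Guard().use(RegexMatch(regex=r"\d{3}-\d{3}-\d{4}"))

prompt = ChatPromptTemplate.from_template("Invent a plausible US phone number for {name}.")
model = ChatOpenAI(model="gpt-4o-mini")

# guard.to_runnable() (assumed integration API) lets the Guard validate the chain's output.
chain = prompt | model | StrOutputParser() | guard.to_runnable()

print(chain.invoke({"name": "Ada Lovelace"}))
```
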
2 changes: 1 addition & 1 deletion guardrails/classes/llm/llm_response.py
@@ -56,7 +56,7 @@ def to_interface(self) -> ILLMResponse:
async_stream_output.append(so)
awaited_stream_output.append(str(async_to_sync(so)))

-self.async_stream_output = aiter(async_stream_output) # type: ignore
+self.async_stream_output = aiter(async_stream_output) # type: ignore # noqa: F821

return ILLMResponse(
prompt_token_count=self.prompt_token_count, # type: ignore - pyright doesn't understand aliases
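The added `# noqa: F821` silences flake8's undefined-name check: `aiter` is a builtin only since Python 3.10, so linters configured against older versions can report it as undefined. A small self-contained sketch of the builtin's behaviour (the chunk strings are invented for the example):

```python
import asyncio

async def chunks():
    # Stand-in for streamed LLM output.
    for piece in ("Hello", ", ", "world"):
        yield piece

async def main() -> None:
    # aiter() is the async counterpart of iter(): it returns an async iterator
    # over an async iterable. Older flake8 setups flag it as undefined (F821).
    stream = aiter(chunks())  # noqa: F821
    print(await anext(stream))   # "Hello"
    async for piece in stream:   # remaining chunks
        print(piece)

asyncio.run(main())
```
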
22 changes: 11 additions & 11 deletions guardrails/schema/pydantic_schema.py
@@ -88,22 +88,22 @@ def get_base_model(
type_origin = get_origin(pydantic_class)
key_type_origin = None

-if type_origin == list:
+if type_origin is list:
item_types = get_args(pydantic_class)
if len(item_types) > 1:
raise ValueError("List data type must have exactly one child.")
item_type = safe_get(item_types, 0)
if not item_type or not issubclass(item_type, BaseModel):
raise ValueError("List item type must be a Pydantic model.")
schema_model = item_type
-elif type_origin == dict:
+elif type_origin is dict:
key_value_types = get_args(pydantic_class)
value_type = safe_get(key_value_types, 1)
key_type_origin = safe_get(key_value_types, 0)
if not value_type or not issubclass(value_type, BaseModel):
raise ValueError("Dict value type must be a Pydantic model.")
schema_model = value_type
-elif type_origin == Union:
+elif type_origin is Union:
union_members = get_args(pydantic_class)
model_members = list(filter(is_base_model_type, union_members))
if len(model_members) > 0:
@@ -141,7 +141,7 @@ def extract_union_member(
field_model, field_type_origin, key_type_origin = try_get_base_model(member)
if not field_model:
return member
-if field_type_origin == Union:
+if field_type_origin is Union:
union_members = get_args(field_model)
extracted_union_members = []
for m in union_members:
@@ -157,9 +157,9 @@
json_path=json_path,
aliases=aliases,
)
-if field_type_origin == list:
+if field_type_origin is list:
return List[extracted_field_model]
-elif field_type_origin == dict:
+elif field_type_origin is dict:
return Dict[key_type_origin, extracted_field_model] # type: ignore
return extracted_field_model

@@ -231,7 +231,7 @@ def extract_validators(
field.annotation
)
if field_model:
-if field_type_origin == Union:
+if field_type_origin is Union:
union_members = list(get_args(field_model))
extracted_union_members = []
for m in union_members:
@@ -254,11 +254,11 @@
json_path=field_path,
aliases=alias_paths,
)
-if field_type_origin == list:
+if field_type_origin is list:
model.model_fields[field_name].annotation = List[
extracted_field_model
]
-elif field_type_origin == dict:
+elif field_type_origin is dict:
model.model_fields[field_name].annotation = Dict[
key_type_origin, extracted_field_model # type: ignore
]
@@ -276,7 +276,7 @@ def pydantic_to_json_schema(
json_schema = pydantic_class.model_json_schema()
json_schema["title"] = pydantic_class.__name__

-if type_origin == list:
+if type_origin is list:
json_schema = {
"title": f"Array<{json_schema.get('title')}>",
"type": "array",
@@ -294,7 +294,7 @@ def pydantic_model_to_schema(
schema_model, type_origin, _key_type_origin = get_base_model(pydantic_class)

processed_schema.output_type = (
-OutputTypes.LIST if type_origin == list else OutputTypes.DICT
+OutputTypes.LIST if type_origin is list else OutputTypes.DICT
)

model = extract_validators(schema_model, processed_schema, "$")
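The substance of this file's change is swapping `==` for `is` when testing the result of `typing.get_origin()`. Since `get_origin()` returns the actual `list`/`dict` class objects (or the `Union` special form), identity comparison is exact and idiomatic, and it avoids surprises from types that override `__eq__`. A quick standalone check of that behaviour (the `Person` model is invented for the example):

```python
from typing import Dict, List, Optional, Union, get_args, get_origin

from pydantic import BaseModel


class Person(BaseModel):
    name: str


# get_origin() hands back the real builtin classes / typing special forms,
# so identity checks with `is` are exact.
assert get_origin(List[Person]) is list
assert get_origin(Dict[str, Person]) is dict
assert get_origin(Union[Person, int]) is Union
assert get_origin(Optional[Person]) is Union  # Optional[X] is Union[X, None]

# get_args() recovers the inner types that the schema code inspects.
assert get_args(List[Person]) == (Person,)
assert get_args(Dict[str, Person]) == (str, Person)
print("all origin checks pass")
```
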
4 changes: 2 additions & 2 deletions guardrails/utils/pydantic_utils.py
@@ -28,7 +28,7 @@ def convert_pydantic_model_to_openai_fn(
schema_model = model

type_origin = get_origin(model)
-if type_origin == list:
+if type_origin is list:
item_types = get_args(model)
if len(item_types) > 1:
raise ValueError("List data type must have exactly one child.")
@@ -41,7 +41,7 @@ def convert_pydantic_model_to_openai_fn(
json_schema = schema_model.model_json_schema()
json_schema["title"] = schema_model.__name__

-if type_origin == list:
+if type_origin is list:
json_schema = {
"title": f"Array<{json_schema.get('title')}>",
"type": "array",
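The same `==` → `is` substitution is applied here in `convert_pydantic_model_to_openai_fn`. As a rough illustration of what the surrounding code produces for a `List[Model]` input, the sketch below rebuilds the array-wrapping step visible in the hunk; the `Person` model and the standalone `wrap_schema` helper are invented for the example, and the `"items"` key is an assumption since the hunk is truncated after `"type"`.

```python
from typing import List, get_args, get_origin

from pydantic import BaseModel


class Person(BaseModel):
    name: str
    age: int


def wrap_schema(model) -> dict:
    """Hypothetical helper mirroring the hunk: unwrap List[Model], emit an array schema."""
    schema_model = model
    if get_origin(model) is list:
        (schema_model,) = get_args(model)  # exactly one item type, as the code enforces
    json_schema = schema_model.model_json_schema()
    json_schema["title"] = schema_model.__name__
    if get_origin(model) is list:
        json_schema = {
            "title": f"Array<{json_schema.get('title')}>",
            "type": "array",
            "items": json_schema,  # assumed key; the diff is truncated here
        }
    return json_schema


print(wrap_schema(List[Person])["title"])  # Array<Person>
```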