Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aviary support #5661

Merged
merged 4 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions docs/modules/models/llms/integrations/aviary.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9597802c",
"metadata": {},
"source": [
"# Aviary\n",
"\n",
"[Aviary](https://www.anyscale.com/) is an open source toolkit for evaluating and deploying production open source LLMs. \n",
"\n",
"This example goes over how to use LangChain to interact with `Aviary`. You can try Aviary out [here](https://aviary.anyscale.com).\n",
"\n",
"You can find out more about Aviary at https://github.com/ray-project/aviary. \n",
"\n",
"One Aviary instance can serve multiple models. You can get a list of the available models by using the cli:\n",
"\n",
"`% aviary models`\n",
"\n",
"Or you can connect directly to the endpoint and get a list of available models by using the `/models` endpoint.\n",
"\n",
"The constructor requires a url for an Aviary backend, and optionally a token to validate the connection. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6fb585dd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from langchain.llms import Aviary\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3fec5a59",
"metadata": {},
"outputs": [],
"source": [
"llm = Aviary(model='amazon/LightGPT', aviary_url=os.environ['AVIARY_URL'], aviary_token=os.environ['AVIARY_TOKEN'])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4efd54dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Love is an emotion that involves feelings of attraction, affection and empathy for another person. It can also refer to a deep bond between two people or groups of people. Love can be expressed in many different ways, such as through words, actions, gestures, music, art, literature, and other forms of communication.\n"
]
}
],
"source": [
"result = llm.predict('What is the meaning of love?')\n",
"print(result) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e526b6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 3 additions & 0 deletions langchain/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langchain.llms.aleph_alpha import AlephAlpha
from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.beam import Beam
Expand Down Expand Up @@ -47,6 +48,7 @@
"Anthropic",
"AlephAlpha",
"Anyscale",
"Aviary",
"Banana",
"Beam",
"Bedrock",
Expand Down Expand Up @@ -94,6 +96,7 @@
"aleph_alpha": AlephAlpha,
"anthropic": Anthropic,
"anyscale": Anyscale,
"aviary": Aviary,
"bananadev": Banana,
"beam": Beam,
"cerebriumai": CerebriumAI,
Expand Down
136 changes: 136 additions & 0 deletions langchain/llms/aviary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

TIMEOUT = 60


class Aviary(LLM):
    """Allow you to use an Aviary.

    Aviary is a backend for hosted models. You can
    find out more about aviary at
    https://github.com/ray-project/aviary

    Has no dependencies, since it connects to the backend
    directly over HTTP.

    To get a list of the models supported on an
    aviary, follow the instructions on the web site to
    install the aviary CLI and then use:
    `aviary models`

    You must at least specify the environment
    variable or parameter AVIARY_URL.

    You may optionally specify the environment variable
    or parameter AVIARY_TOKEN.

    Example:
        .. code-block:: python

            from langchain.llms import Aviary
            light = Aviary(aviary_url='AVIARY_URL',
                model='amazon/LightGPT')

            result = light.predict('How do you make fried rice?')
    """

    # Model id as reported by the backend's /models endpoint, e.g. "amazon/LightGPT".
    model: str
    # Base URL of the Aviary backend; normalized to end with "/" by the validator.
    aviary_url: str
    # Optional bearer token; excluded from pydantic serialization.
    aviary_token: str = Field("", exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the url/token and confirm the requested model is served.

        Reads AVIARY_URL / AVIARY_TOKEN from the values dict or the
        environment, then queries the backend's /models endpoint to
        verify that ``values["model"]`` is available.

        Raises:
            ValueError: if the backend is unreachable or does not serve
                the requested model.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        values["aviary_url"] = aviary_url
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        values["aviary_token"] = aviary_token

        aviary_endpoint = aviary_url + "models"
        headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
        try:
            # Bound the request so a dead backend cannot hang construction.
            response = requests.get(aviary_endpoint, headers=headers, timeout=TIMEOUT)
            result = response.json()
            # Confirm model is available
            if values["model"] not in result:
                raise ValueError(
                    f"{aviary_url} does not support model {values['model']}."
                )

        except requests.exceptions.RequestException as e:
            # Preserve the original traceback for easier debugging.
            raise ValueError(e) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # NOTE(review): exposing the raw token here leaks a secret into logs
        # and serialized LLM configs, despite Field(exclude=True) on the
        # attribute. Kept for backward compatibility — confirm with callers
        # before removing.
        return {
            "aviary_url": self.aviary_url,
            "aviary_token": self.aviary_token,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "aviary"

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers for backend requests (bearer auth when a token is set)."""
        if self.aviary_token:
            return {"Authorization": f"Bearer {self.aviary_token}"}
        return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Aviary.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop tokens to truncate the generated text.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: if the backend response is not JSON or lacks the
                expected ``{model: {"generated_text": ...}}`` shape.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # Backend routes use "--" in place of "/" in model ids.
        url = self.aviary_url + "query/" + self.model.replace("/", "--")
        response = requests.post(
            url,
            headers=self.headers,
            json={"prompt": prompt},
            timeout=TIMEOUT,
        )
        try:
            text = response.json()[self.model]["generated_text"]
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error decoding JSON from {url}. Text response: {response.text}",
            ) from e
        except KeyError as e:
            # Response was valid JSON but not the expected shape.
            raise ValueError(
                f"Unexpected response from {url}: missing key {e}. "
                f"Text response: {response.text}",
            ) from e
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text
10 changes: 10 additions & 0 deletions tests/integration_tests/llms/test_aviary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test Anyscale API wrapper."""

from langchain.llms.aviary import Aviary


def test_aviary_call() -> None:
    """Test a valid call to an Aviary backend returns a string."""
    # Requires AVIARY_URL (and optionally AVIARY_TOKEN) in the environment.
    llm = Aviary(model="test/model")
    output = llm("Say bar:")
    assert isinstance(output, str)