feat: add LiteLLM support #549

Closed · wants to merge 14 commits
Changes from 9 commits
9 changes: 6 additions & 3 deletions aider/coders/base_coder.py
@@ -73,7 +73,7 @@ def create(

if not skip_model_availabily_check and not main_model.always_available:
if not check_model_availability(io, client, main_model):
fallback_model = models.GPT35_0125
fallback_model = models.Model.create("gpt-3.5-turbo-0125")
io.tool_error(
f"API key does not support {main_model.name}, falling back to"
f" {fallback_model.name}"
@@ -160,7 +160,7 @@ def __init__(
if use_git:
try:
self.repo = GitRepo(
self.io, fnames, git_dname, aider_ignore_file, client=self.client
self.io, fnames, git_dname, main_model, aider_ignore_file, client=self.client
)
self.root = self.repo.root
except FileNotFoundError:
@@ -224,7 +224,7 @@ def __init__(

self.summarizer = ChatSummary(
self.client,
models.Model.weak_model(),
self.main_model.get_weak_model(),
self.main_model.max_chat_history_tokens,
)

@@ -1055,6 +1055,9 @@ def dirty_commit(self):


def check_model_availability(io, client, main_model):
if not hasattr(client, "models"):
return True

try:
available_models = client.models.list()
except openai.NotFoundError:
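For context on the new early return in check_model_availability: the function probes client.models.list(), which OpenAI-style clients provide but (per this PR) LiteLLM's client does not, so any client without a models attribute is now treated as always available. A minimal illustration with hypothetical stand-in clients:

class StubLiteLLMClient:
    # Stand-in for litellm.LiteLLM, assumed here to lack a `models` attribute.
    pass

class StubOpenAIClient:
    class models:
        @staticmethod
        def list():
            return []

assert not hasattr(StubLiteLLMClient(), "models")  # guard returns True early
assert hasattr(StubOpenAIClient(), "models")       # availability probe still runs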
90 changes: 76 additions & 14 deletions aider/main.py
@@ -13,6 +13,7 @@
from aider.io import InputOutput
from aider.repo import GitRepo
from aider.versioncheck import check_version
from aider.models.litellm import LITELLM_SPEC

from .dump import dump # noqa: F401

@@ -163,7 +164,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
core_group.add_argument(
"--model",
metavar="MODEL",
default=default_model,
help=f"Specify the model to use for the main chat (default: {default_model})",
)
core_group.add_argument(
@@ -190,16 +190,16 @@
const=default_4_turbo_model,
help=f"Use {default_4_turbo_model} model for the main chat",
)
default_3_model = models.GPT35_0125
default_3_model = "gpt-3.5-turbo-0125"
core_group.add_argument(
"--35turbo",
"--35-turbo",
"--3",
"-3",
action="store_const",
dest="model",
const=default_3_model.name,
help=f"Use {default_3_model.name} model for the main chat",
const=default_3_model,
help=f"Use {default_3_model} model for the main chat",
)
core_group.add_argument(
"--voice-language",
@@ -247,6 +247,11 @@ def main(argv=None, input=None, output=None, force_git_root=None):
const="https://openrouter.ai/api/v1",
help="Specify the api base url as https://openrouter.ai/api/v1",
)
model_group.add_argument(
"--litellm",
action="store_true",
help="Use litellm instead of openai",
)
model_group.add_argument(
"--edit-format",
metavar="EDIT_FORMAT",
@@ -543,6 +548,8 @@ def main(argv=None, input=None, output=None, force_git_root=None):
check_gitignore(git_root, io)

def scrub_sensitive_info(text):
if not args.openai_api_key:
return text
# Replace sensitive information with placeholder
return text.replace(args.openai_api_key, "***")

@@ -557,18 +564,73 @@ def scrub_sensitive_info(text):

io.tool_output(*map(scrub_sensitive_info, sys.argv), log_only=True)

if not args.openai_api_key:
if os.name == "nt":
io.tool_error(
"No OpenAI API key provided. Use --openai-api-key or setx OPENAI_API_KEY."
)
else:
io.tool_error(
"No OpenAI API key provided. Use --openai-api-key or export OPENAI_API_KEY."
)
if args.litellm:
if not args.model:
io.tool_error("You must specify --model or AIDER_MODEL environment variable when using --litellm.")
return 1

elif not (args.openai_api_key or args.openai_api_base) and \
args.model is not None and \
LITELLM_SPEC is not None:
io.tool_output(f"OpenAI key not provided, using LiteLLM instead.")
args.litellm = True

elif not args.openai_api_key:
export_kw = "setx" if os.name == "nt" else "export"
io.tool_error(
f"No OpenAI API key provided. Use --openai-api-key or {export_kw} OPENAI_API_KEY."
)
return 1

if args.openai_api_type == "azure":
if not args.model:
args.model = default_model

if args.litellm:
if LITELLM_SPEC is None:
io.tool_error("LiteLLM is not installed. Install it with `pip install litellm`.")
return 1

io.tool_output("LiteLLM is enabled.")

packageRootDir = os.path.dirname(LITELLM_SPEC.origin)
modelPricesBackupName = "model_prices_and_context_window_backup.json"
modelPricesPath = os.path.join(packageRootDir, modelPricesBackupName)

# LiteLLM keeps a backup of the model prices for when the network is
# down, but it's otherwise not used unless LITELLM_LOCAL_MODEL_COST_MAP
# is set to True. To keep Aider's startup time fast, prevent the network
# request unless the model prices haven't been updated in the last 12 hours.
from datetime import datetime, timedelta
timeSinceModified = datetime.now() - datetime.fromtimestamp(os.path.getmtime(modelPricesPath))

if timeSinceModified < timedelta(hours=12):
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
io.tool_output("LiteLLM model list backup is recent, using it instead of network request.")
else:
def strfdelta(delta):
days = delta.days
hours, remainder = divmod(delta.seconds, 3600)
minutes, seconds = divmod(remainder, 60)
time_parts = [f"{days} day{'' if days == 1 else 's'}" if days else "",
f"{hours} hour{'' if hours == 1 else 's'}" if hours else "",
f"{minutes} minute{'' if minutes == 1 else 's'}" if minutes else "",
f"{seconds} second{'' if seconds == 1 else 's'}"]
time_parts = [part for part in time_parts if part]
return ", ".join(time_parts)

formattedTime = strfdelta(timeSinceModified)
io.tool_output(f"LiteLLM model list backup last updated {formattedTime} ago, updating now.")

# Since LiteLLM never updates its backup after hitting the
# network, we'll have to do it ourselves.
from litellm import model_cost
with open(modelPricesPath, "w") as f:
import json
json.dump(model_cost, f)

from litellm import LiteLLM
client = LiteLLM()
elif args.openai_api_type == "azure":
client = openai.AzureOpenAI(
api_key=args.openai_api_key,
azure_endpoint=args.openai_api_base,
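The model-prices caching added above can be read as a standalone recipe. A minimal sketch under the diff's own assumptions (litellm installed, its backup JSON living at the package root, and LITELLM_LOCAL_MODEL_COST_MAP acting as litellm's switch for using the local map):

import importlib.util
import json
import os
from datetime import datetime, timedelta

# Locate litellm's bundled price backup, exactly as the diff does.
spec = importlib.util.find_spec("litellm")
backup_path = os.path.join(
    os.path.dirname(spec.origin), "model_prices_and_context_window_backup.json"
)

age = datetime.now() - datetime.fromtimestamp(os.path.getmtime(backup_path))

if age < timedelta(hours=12):
    # Fresh enough: have litellm use its local map and skip the network.
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
else:
    # Stale: let litellm fetch prices, then rewrite the backup ourselves,
    # since litellm does not refresh it after a successful fetch.
    from litellm import model_cost
    with open(backup_path, "w") as f:
        json.dump(model_cost, f)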
8 changes: 1 addition & 7 deletions aider/models/__init__.py
@@ -2,16 +2,10 @@
from .openai import OpenAIModel
from .openrouter import OpenRouterModel

GPT4 = Model.create("gpt-4")
GPT35 = Model.create("gpt-3.5-turbo")
GPT35_0125 = Model.create("gpt-3.5-turbo-0125")

DEFAULT_MODEL_NAME = "gpt-4-1106-preview"

__all__ = [
Model,
OpenAIModel,
OpenRouterModel,
GPT4,
GPT35,
GPT35_0125,
]
67 changes: 67 additions & 0 deletions aider/models/litellm.py
@@ -0,0 +1,67 @@
import importlib.util
import logging
import tiktoken

from .model import Model

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("aider-litellm")

LITELLM_SPEC = importlib.util.find_spec("litellm")

model_aliases = {
# claude-3
"opus": "claude-3-opus-20240229",
"sonnet": "claude-3-sonnet-20240229",
"haiku": "claude-3-haiku-20240307",
# gemini-1.5-pro
"gemini": "gemini/gemini-1.5-pro-latest",
# gpt-3.5
"gpt-3.5": "gpt-3.5-turbo-0613",
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
# gpt-4
"gpt-4": "gpt-4-0613",
"gpt-4-32k": "gpt-4-32k-0613",
}

models_info = None

class LiteLLMModel(Model):
def __init__(self, name):
model_id = name
if name in model_aliases:
model_id = model_aliases[name]

from litellm import model_cost

model_data = model_cost.get(model_id)
if not model_data:
# For gemini 1.5 pro to work, LiteLLM appears to need the "-latest"
# part included in the model name, but it's not included in the list
# of supported models that way, so finesse it here
if model_id == "gemini/gemini-1.5-pro-latest":
model_data = model_cost.get("gemini/gemini-1.5-pro")
if not model_data:
raise ValueError(f"Unsupported model: {model_id}")
else:
raise ValueError(f"Unsupported model: {model_id}")

self.tokenizer = tiktoken.get_encoding("cl100k_base")

self.name = model_id
self.max_context_tokens = model_data.get("max_input_tokens")
self.prompt_price = model_data.get("input_cost_per_token") * 100
self.completion_price = model_data.get("output_cost_per_token") * 100

is_high_end = model_id.startswith("gpt-4") or model_id.startswith("claude-3-opus")

self.edit_format = "udiff" if is_high_end else "whole"
self.use_repo_map = is_high_end
self.send_undo_reply = is_high_end

# set the history token limit
if self.max_context_tokens < 32 * 1024:
self.max_chat_history_tokens = 1024
else:
self.max_chat_history_tokens = 2 * 1024
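A hypothetical usage sketch of the new class, assuming litellm is installed and its model_cost map knows the Claude 3 models:

from aider.models.litellm import LiteLLMModel

model = LiteLLMModel("opus")  # alias resolves via model_aliases

print(model.name)                     # claude-3-opus-20240229
print(model.edit_format)              # "udiff" -- gpt-4 and claude-3-opus count as high end
print(model.use_repo_map)             # True
print(model.max_chat_history_tokens)  # 2048, since the context window exceeds 32k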
10 changes: 7 additions & 3 deletions aider/models/model.py
@@ -22,7 +22,10 @@ class Model:
def create(cls, name, client=None):
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
from .litellm import LiteLLMModel

if client and not hasattr(client, "base_url"):
Reviewer comment:

Not sure about this... I set up LiteLLM in Docker on a VM for my team to use, the idea being we can set up the models in one place and everyone has access (and all our projects), but this would require setting a URL or, I guess, a hacky port forward. It would be nice to be able to set a URL for LiteLLM in case it's not running locally.

Author reply:

This condition doesn't preclude setting a base_url for LiteLLM, but I understand why you'd think that. It's just duck-typing to detect the LiteLLM client, which doesn't have a base_url property. There's probably a better way (with more clarity) to detect that, so I'll see what I can do.
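A more explicit alternative to the base_url probe, sketched as a hypothetical helper (assuming litellm is importable wherever the check runs):

def is_litellm_client(client):
    # Explicit type check instead of duck-typing on a missing base_url.
    try:
        from litellm import LiteLLM
    except ImportError:
        return False
    return isinstance(client, LiteLLM)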

return LiteLLMModel(name)
if client and client.base_url.host == "openrouter.ai":
return OpenRouterModel(client, name)
return OpenAIModel(name)
@@ -38,9 +41,10 @@ def strong_model():
def weak_model():
return Model.create("gpt-3.5-turbo-0125")

@staticmethod
def commit_message_models():
return [Model.weak_model()]
def get_weak_model(self):
if self.use_repo_map:
return Model.weak_model()
return self

def token_count(self, messages):
if not self.tokenizer:
9 changes: 4 additions & 5 deletions aider/repo.py
@@ -16,7 +16,8 @@ class GitRepo:
aider_ignore_spec = None
aider_ignore_ts = 0

def __init__(self, io, fnames, git_dname, aider_ignore_file=None, client=None):
def __init__(self, io, fnames, git_dname, main_model=None, aider_ignore_file=None, client=None):
self.main_model = main_model
self.client = client
self.io = io

@@ -120,10 +121,8 @@ def get_commit_message(self, diffs, context):
dict(role="user", content=content),
]

for model in models.Model.commit_message_models():
commit_message = simple_send_with_retries(self.client, model.name, messages)
if commit_message:
break
commit_model = self.main_model.get_weak_model()
commit_message = simple_send_with_retries(self.client, commit_model.name, messages)

if not commit_message:
self.io.tool_error("Failed to generate commit message!")