Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MistralAI client #12

Merged
merged 8 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pip install poetry
### Install project

```shell
poetry install --with dev,docs
poetry install --all-extras --with dev,docs
```


Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pip install genai_impact
## Basic example for OpenAI

```python
from genai_impact.client_wrapper import OpenAI
from genai_impact import OpenAI

client = OpenAI()

Expand Down
4 changes: 2 additions & 2 deletions genai_impact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from genai_impact.client_wrapper import OpenAI
from genai_impact.client import MistralClient, OpenAI

__all__ = ["OpenAI"]
__all__ = ["OpenAI", "MistralClient"]
9 changes: 0 additions & 9 deletions genai_impact/__main__.py

This file was deleted.

4 changes: 4 additions & 0 deletions genai_impact/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .mistralai_wrapper import MistralClient
from .openai_wrapper import OpenAI

__all__ = ["OpenAI", "MistralClient"]
45 changes: 45 additions & 0 deletions genai_impact/client/mistralai_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, Callable

from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact

try:
    from mistralai.client import MistralClient as _MistralClient
    from mistralai.models.chat_completion import (
        ChatCompletionResponse as _ChatCompletionResponse,
    )
except ImportError:
    # mistralai may not be installed. Fall back to the `object` *class*, not an
    # `object()` instance: both names are used as base classes in this module,
    # and a class statement whose base is an instance raises TypeError, which
    # would make the module impossible to import without mistralai.
    _MistralClient = object
    _ChatCompletionResponse = object


# Size figure per Mistral model name. `chat_wrapper` looks the response's
# model name up here and passes the result on as `model_parameter_count`.
# NOTE(review): values are presumably billions of parameters — confirm against
# `compute_llm_impact`.
_MODEL_SIZES = {
"mistral-tiny": 7.3,
"mistral-small": 12.9, # mixtral active parameters count
"mistral-medium": 70,
"mistral-large": 220,
}


# Chat completion response type returned by `chat_wrapper`: the upstream
# response model plus one extra field holding the computed impacts.
class ChatCompletionResponse(_ChatCompletionResponse):
    # Impact estimate computed for this completion.
    impacts: Impacts


def chat_wrapper(
    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletionResponse:
    """Call the wrapped chat method and return its response with impacts attached.

    Args:
        wrapped: The original callable being wrapped.
        instance: The object the call is bound to (unused).
        args: Positional arguments of the call, forwarded unchanged.
        kwargs: Keyword arguments of the call, forwarded unchanged.

    Returns:
        A ``ChatCompletionResponse`` built from the original response's dumped
        fields plus the computed ``impacts``.
    """
    completion = wrapped(*args, **kwargs)
    # A model name missing from the table yields None, passed through as-is.
    impacts = compute_llm_impact(
        model_parameter_count=_MODEL_SIZES.get(completion.model),
        output_token_count=completion.usage.completion_tokens,
    )
    return ChatCompletionResponse(**completion.model_dump(), impacts=impacts)


class MistralClient(_MistralClient):
    """Mistral client whose ``chat`` responses carry an ``impacts`` field.

    Instantiating the client patches ``mistralai.client.MistralClient.chat``
    with ``chat_wrapper``. The patch targets the class attribute, so it is
    installed only once, however many clients are created.
    """

    # True once `chat` has been patched. Without this guard every new client
    # would wrap the already-wrapped method again; the outer wrapper would then
    # receive a response whose dump already contains `impacts` and fail with a
    # duplicate-keyword TypeError when rebuilding the response.
    _chat_is_wrapped = False

    def __init__(self, **kwargs: Any) -> None:
        """Forward ``kwargs`` to the underlying client, then patch ``chat`` once."""
        super().__init__(**kwargs)

        if not MistralClient._chat_is_wrapped:
            wrap_function_wrapper(
                "mistralai.client", "MistralClient.chat", chat_wrapper
            )
            MistralClient._chat_is_wrapped = True
53 changes: 53 additions & 0 deletions genai_impact/client/openai_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Callable

from openai import OpenAI as _OpenAI
from openai.resources.chat import Completions
from openai.types.chat import ChatCompletion as _ChatCompletion
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact

# Size figure per OpenAI model name. `chat_wrapper` looks the response's model
# name up here and passes the result on as `model_parameter_count`.
# NOTE(review): values are presumably billions of parameters, and None
# presumably marks models with no size estimate — confirm how
# `compute_llm_impact` treats None.
_MODEL_SIZES = {
"gpt-4-0125-preview": None,
"gpt-4-turbo-preview": None,
"gpt-4-1106-preview": None,
"gpt-4-vision-preview": None,
"gpt-4": 220,
"gpt-4-0314": 220,
"gpt-4-0613": 220,
"gpt-4-32k": 220,
"gpt-4-32k-0314": 220,
"gpt-4-32k-0613": 220,
"gpt-3.5-turbo": 20,
"gpt-3.5-turbo-16k": 20,
"gpt-3.5-turbo-0301": 20,
"gpt-3.5-turbo-0613": 20,
"gpt-3.5-turbo-1106": 20,
"gpt-3.5-turbo-0125": 20,
"gpt-3.5-turbo-16k-0613": 20,
}


# Chat completion type returned by `chat_wrapper`: the upstream response model
# plus one extra field holding the computed impacts.
class ChatCompletion(_ChatCompletion):
    # Impact estimate computed for this completion.
    impacts: Impacts


def chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletion:
    """Call the wrapped ``create`` method and return its result with impacts attached.

    Args:
        wrapped: The original callable being wrapped.
        instance: The object the call is bound to (unused).
        args: Positional arguments of the call, forwarded unchanged.
        kwargs: Keyword arguments of the call, forwarded unchanged.

    Returns:
        A ``ChatCompletion`` built from the original response's dumped fields
        plus the computed ``impacts``.
    """
    completion = wrapped(*args, **kwargs)
    # A model name missing from the table (or mapped to None) yields None,
    # which is passed through as-is.
    impacts = compute_llm_impact(
        model_parameter_count=_MODEL_SIZES.get(completion.model),
        output_token_count=completion.usage.completion_tokens,
    )
    return ChatCompletion(**completion.model_dump(), impacts=impacts)


class OpenAI(_OpenAI):
    """OpenAI client whose chat completions carry an ``impacts`` field.

    Instantiating the client patches
    ``openai.resources.chat.completions.Completions.create`` with
    ``chat_wrapper``. The patch targets the class attribute, so it is installed
    only once, however many clients are created.
    """

    # True once `Completions.create` has been patched. Without this guard every
    # new client would wrap the already-wrapped method again; the outer wrapper
    # would then receive a response whose dump already contains `impacts` and
    # fail with a duplicate-keyword TypeError when rebuilding the response.
    _create_is_wrapped = False

    def __init__(self, **kwargs: Any) -> None:
        """Forward ``kwargs`` to the underlying client, then patch ``create`` once."""
        super().__init__(**kwargs)

        if not OpenAI._create_is_wrapped:
            wrap_function_wrapper(
                "openai.resources.chat.completions", "Completions.create", chat_wrapper
            )
            OpenAI._create_is_wrapped = True
47 changes: 0 additions & 47 deletions genai_impact/client_wrapper.py

This file was deleted.

25 changes: 2 additions & 23 deletions genai_impact/compute_impacts.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,9 @@
from dataclasses import dataclass
from pydantic import BaseModel

ENERGY_PROFILE = 1.17e-4

MODEL_SIZES = {
"gpt-4-0125-preview": None,
"gpt-4-turbo-preview": None,
"gpt-4-1106-preview": None,
"gpt-4-vision-preview": None,
"gpt-4": 200,
"gpt-4-0314": 200,
"gpt-4-0613": 200,
"gpt-4-32k": 200,
"gpt-4-32k-0314": 200,
"gpt-4-32k-0613": 200,
"gpt-3.5-turbo": 20,
"gpt-3.5-turbo-16k": 20,
"gpt-3.5-turbo-0301": 20,
"gpt-3.5-turbo-0613": 20,
"gpt-3.5-turbo-1106": 20,
"gpt-3.5-turbo-0125": 20,
"gpt-3.5-turbo-16k-0613": 20,
}


@dataclass
class Impacts:
class Impacts(BaseModel):
energy: float
energy_unit: str = "Wh"

Expand Down
Loading
Loading