Skip to content

Commit

Permalink
Merge pull request #12 from dataforgoodfr/feat/mistral-7
Browse files Browse the repository at this point in the history
Add support for MistralAI client
  • Loading branch information
samuelrince authored Mar 5, 2024
2 parents 2d9e73f + 269d0c8 commit 8544061
Show file tree
Hide file tree
Showing 11 changed files with 383 additions and 102 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pip install poetry
### Install project

```shell
poetry install --with dev,docs
poetry install --all-extras --with dev,docs
```


Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pip install genai_impact
## Basic example for OpenAI

```python
from genai_impact.client_wrapper import OpenAI
from genai_impact import OpenAI

client = OpenAI()

Expand Down
4 changes: 2 additions & 2 deletions genai_impact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from genai_impact.client_wrapper import OpenAI
from genai_impact.client import MistralClient, OpenAI

__all__ = ["OpenAI"]
__all__ = ["OpenAI", "MistralClient"]
9 changes: 0 additions & 9 deletions genai_impact/__main__.py

This file was deleted.

4 changes: 4 additions & 0 deletions genai_impact/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .mistralai_wrapper import MistralClient
from .openai_wrapper import OpenAI

__all__ = ["OpenAI", "MistralClient"]
45 changes: 45 additions & 0 deletions genai_impact/client/mistralai_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, Callable

from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact

try:
from mistralai.client import MistralClient as _MistralClient
from mistralai.models.chat_completion import (
ChatCompletionResponse as _ChatCompletionResponse,
)
except ImportError:
_MistralClient = object()
_ChatCompletionResponse = object()


# Approximate parameter counts (in billions) per Mistral API model alias,
# used as input to the impact estimation. Proprietary model sizes are
# community estimates — TODO confirm against vendor disclosures.
_MODEL_SIZES = {
    "mistral-tiny": 7.3,
    "mistral-small": 12.9,  # mixtral active parameters count
    "mistral-medium": 70,
    "mistral-large": 220,
}


class ChatCompletionResponse(_ChatCompletionResponse):
    """Mistral chat response extended with estimated environmental impacts.

    Identical to the SDK's ``ChatCompletionResponse`` (a pydantic model)
    plus one extra field populated by :func:`chat_wrapper`.
    """

    # Computed energy/impact figures for the completion that produced this response.
    impacts: Impacts


def chat_wrapper(
    wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletionResponse:
    """wrapt hook for ``MistralClient.chat``.

    Invokes the wrapped method, estimates the impacts of the call from the
    model name and completion token count reported by the API, and returns
    the response rebuilt with an additional ``impacts`` field.
    """
    raw_response = wrapped(*args, **kwargs)
    # NOTE(review): a model alias missing from _MODEL_SIZES yields None here;
    # presumably compute_llm_impact tolerates None — confirm.
    estimated_impacts = compute_llm_impact(
        model_parameter_count=_MODEL_SIZES.get(raw_response.model),
        output_token_count=raw_response.usage.completion_tokens,
    )
    return ChatCompletionResponse(
        **raw_response.model_dump(), impacts=estimated_impacts
    )


class MistralClient(_MistralClient):
    """Drop-in replacement for ``mistralai.client.MistralClient``.

    Accepts the same keyword arguments as the upstream client. Note that the
    module-level patch below affects *every* MistralClient instance, not only
    instances of this subclass.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)


# Monkey-patch the upstream MistralClient.chat at import time so all chat
# responses carry an `impacts` field. wrap_function_wrapper imports
# "mistralai.client", which raises ImportError when the optional mistralai
# package is absent — guard it so importing this module still succeeds
# (consistent with the import fallback at the top of the file).
try:
    wrap_function_wrapper("mistralai.client", "MistralClient.chat", chat_wrapper)
except ImportError:
    pass
53 changes: 53 additions & 0 deletions genai_impact/client/openai_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Callable

from openai import OpenAI as _OpenAI
from openai.resources.chat import Completions
from openai.types.chat import ChatCompletion as _ChatCompletion
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import Impacts, compute_llm_impact

# Approximate parameter counts (in billions) per OpenAI model alias, used as
# input to the impact estimation. `None` marks models whose size is
# undisclosed, so no estimate can be computed. All values are community
# estimates — TODO confirm if OpenAI ever discloses official figures.
_MODEL_SIZES = {
    "gpt-4-0125-preview": None,
    "gpt-4-turbo-preview": None,
    "gpt-4-1106-preview": None,
    "gpt-4-vision-preview": None,
    "gpt-4": 220,
    "gpt-4-0314": 220,
    "gpt-4-0613": 220,
    "gpt-4-32k": 220,
    "gpt-4-32k-0314": 220,
    "gpt-4-32k-0613": 220,
    "gpt-3.5-turbo": 20,
    "gpt-3.5-turbo-16k": 20,
    "gpt-3.5-turbo-0301": 20,
    "gpt-3.5-turbo-0613": 20,
    "gpt-3.5-turbo-1106": 20,
    "gpt-3.5-turbo-0125": 20,
    "gpt-3.5-turbo-16k-0613": 20,
}


class ChatCompletion(_ChatCompletion):
    """OpenAI chat completion extended with estimated environmental impacts.

    Identical to the SDK's ``ChatCompletion`` (a pydantic model) plus one
    extra field populated by :func:`chat_wrapper`.
    """

    # Computed energy/impact figures for the completion that produced this response.
    impacts: Impacts


def chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any  # noqa: ARG001
) -> ChatCompletion:
    """wrapt hook for ``Completions.create``.

    Invokes the wrapped method, estimates the impacts of the call from the
    model name and completion token count reported by the API, and returns
    the response rebuilt with an additional ``impacts`` field.
    """
    raw_response = wrapped(*args, **kwargs)
    # NOTE(review): several gpt-4 aliases map to None in _MODEL_SIZES;
    # presumably compute_llm_impact tolerates None — confirm.
    estimated_impacts = compute_llm_impact(
        model_parameter_count=_MODEL_SIZES.get(raw_response.model),
        output_token_count=raw_response.usage.completion_tokens,
    )
    return ChatCompletion(**raw_response.model_dump(), impacts=estimated_impacts)


class OpenAI(_OpenAI):
    """Drop-in replacement for ``openai.OpenAI``.

    Accepts the same keyword arguments as the upstream client. Note that the
    module-level patch below affects *every* OpenAI client instance, not only
    instances of this subclass.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)


# Monkey-patch the upstream Completions.create at import time so all chat
# completions carry an `impacts` field (openai is a hard dependency of this
# module — no import guard needed, unlike the mistralai wrapper).
wrap_function_wrapper(
    "openai.resources.chat.completions", "Completions.create", chat_wrapper
)
47 changes: 0 additions & 47 deletions genai_impact/client_wrapper.py

This file was deleted.

25 changes: 2 additions & 23 deletions genai_impact/compute_impacts.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,9 @@
from dataclasses import dataclass
from pydantic import BaseModel

ENERGY_PROFILE = 1.17e-4

MODEL_SIZES = {
"gpt-4-0125-preview": None,
"gpt-4-turbo-preview": None,
"gpt-4-1106-preview": None,
"gpt-4-vision-preview": None,
"gpt-4": 200,
"gpt-4-0314": 200,
"gpt-4-0613": 200,
"gpt-4-32k": 200,
"gpt-4-32k-0314": 200,
"gpt-4-32k-0613": 200,
"gpt-3.5-turbo": 20,
"gpt-3.5-turbo-16k": 20,
"gpt-3.5-turbo-0301": 20,
"gpt-3.5-turbo-0613": 20,
"gpt-3.5-turbo-1106": 20,
"gpt-3.5-turbo-0125": 20,
"gpt-3.5-turbo-16k-0613": 20,
}


@dataclass
class Impacts:
class Impacts(BaseModel):
energy: float
energy_unit: str = "Wh"

Expand Down
Loading

0 comments on commit 8544061

Please sign in to comment.