Skip to content

Commit

Permalink
chore: run pre-commits
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelrince committed Mar 5, 2024
1 parent 3206cf1 commit 269d0c8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 23 deletions.
8 changes: 2 additions & 6 deletions genai_impact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from genai_impact.client import OpenAI
from genai_impact.client import MistralClient
from genai_impact.client import MistralClient, OpenAI

__all__ = [
"OpenAI",
"MistralClient"
]
__all__ = ["OpenAI", "MistralClient"]
7 changes: 2 additions & 5 deletions genai_impact/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from .openai_wrapper import OpenAI
from .mistralai_wrapper import MistralClient
from .openai_wrapper import OpenAI

__all__ = [
"OpenAI",
"MistralClient"
]
__all__ = ["OpenAI", "MistralClient"]
18 changes: 9 additions & 9 deletions genai_impact/client/mistralai_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from typing import Callable, Any
from typing import Any, Callable

from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import compute_llm_impact, Impacts
from genai_impact.compute_impacts import Impacts, compute_llm_impact

try:
from mistralai.client import MistralClient as _MistralClient
from mistralai.models.chat_completion import ChatCompletionResponse as _ChatCompletionResponse
from mistralai.models.chat_completion import (
ChatCompletionResponse as _ChatCompletionResponse,
)
except ImportError:
_MistralClient = object()
_ChatCompletionResponse = object()


_MODEL_SIZES = {
"mistral-tiny": 7.3,
"mistral-small": 12.9, # mixtral active parameters count
"mistral-small": 12.9, # mixtral active parameters count
"mistral-medium": 70,
"mistral-large": 220
"mistral-large": 220,
}


Expand All @@ -25,7 +27,7 @@ class ChatCompletionResponse(_ChatCompletionResponse):


def chat_wrapper(
wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any # noqa: ARG001
wrapped: Callable, instance: _MistralClient, args: Any, kwargs: Any # noqa: ARG001
) -> ChatCompletionResponse:
response = wrapped(*args, **kwargs)
model_size = _MODEL_SIZES.get(response.model)
Expand All @@ -40,6 +42,4 @@ class MistralClient(_MistralClient):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

wrap_function_wrapper(
"mistralai.client", "MistralClient.chat", chat_wrapper
)
wrap_function_wrapper("mistralai.client", "MistralClient.chat", chat_wrapper)
5 changes: 2 additions & 3 deletions genai_impact/client/openai_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from openai.types.chat import ChatCompletion as _ChatCompletion
from wrapt import wrap_function_wrapper

from genai_impact.compute_impacts import compute_llm_impact, Impacts

from genai_impact.compute_impacts import Impacts, compute_llm_impact

_MODEL_SIZES = {
"gpt-4-0125-preview": None,
Expand Down Expand Up @@ -34,7 +33,7 @@ class ChatCompletion(_ChatCompletion):


def chat_wrapper(
wrapped: Callable, instance: Completions, args: Any, kwargs: Any # noqa: ARG001
wrapped: Callable, instance: Completions, args: Any, kwargs: Any # noqa: ARG001
) -> ChatCompletion:
response = wrapped(*args, **kwargs)
model_size = _MODEL_SIZES.get(response.model)
Expand Down

0 comments on commit 269d0c8

Please sign in to comment.