Skip to content

Commit

Permalink
linting, formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ajgray-stripe committed Nov 8, 2024
1 parent aaca598 commit 59e5716
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
14 changes: 9 additions & 5 deletions packages/exchange/src/exchange/observers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from typing import Callable
from functools import wraps
from typing import Callable

from exchange.observers.base import ObserverManager

def observe_wrapper(*args, **kwargs) -> Callable:

def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003
"""Decorator to wrap a function with all registered observer plugins, dynamically fetched."""
def wrapper(func):

def wrapper(func: Callable) -> Callable:
@wraps(func)
def dynamic_wrapped(*func_args, **func_kwargs):
def dynamic_wrapped(*func_args, **func_kwargs) -> Callable: # noqa: ANN002, ANN003
wrapped = func
for observer in ObserverManager.get_instance()._observers:
wrapped = observer.observe_wrapper(*args, **kwargs)(wrapped)
return wrapped(*func_args, **func_kwargs)

return dynamic_wrapped
return wrapper

return wrapper
8 changes: 5 additions & 3 deletions packages/exchange/src/exchange/observers/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, Type


class Observer(ABC):
@abstractmethod
def initialize(self) -> None:
pass

@abstractmethod
def observe_wrapper(*args, **kwargs) -> Callable:
def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003
pass

@abstractmethod
def finalize(self) -> None:
pass


class ObserverManager:
_instance = None
_observers: list[Observer] = []

@classmethod
def get_instance(cls):
def get_instance(cls: Type["ObserverManager"]) -> "ObserverManager":
if cls._instance is None:
cls._instance = cls()
return cls._instance
Expand Down
4 changes: 2 additions & 2 deletions packages/exchange/src/exchange/observers/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def initialize_with_disabled_tracing(self) -> None:
langfuse_context.configure(enabled=False)
self.tracing = False

def observe_wrapper(self, *args, **kwargs) -> Callable:
def observe_wrapper(self, *args, **kwargs) -> Callable: # noqa: ANN002, ANN003
def _wrapper(fn: Callable) -> Callable:
if self.tracing and auth_check():

@wraps(fn)
def wrapped_fn(*fargs, **fkwargs) -> Callable:
def wrapped_fn(*fargs, **fkwargs) -> Callable: # noqa: ANN002, ANN003
# group all traces under the same session
if fn.__name__ == "reply":
langfuse_context.update_current_trace(session_id=fargs[0].name)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from goose.profile import ToolkitSpec, ObserverSpec


def test_profile_info(profile_factory):
profile = profile_factory(
{
"provider": "provider",
"processor": "processor",
"toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")],
"observers": [ObserverSpec(name="test.plugin")]
"observers": [ObserverSpec(name="test.plugin")],
}
)
assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github observers: test.plugin"
assert (
profile.profile_info()
== "provider:provider, processor:processor toolkits: developer, github observers: test.plugin"
)

0 comments on commit 59e5716

Please sign in to comment.