Skip to content

Commit

Permalink
feat: add tracing configuration in config.toml
Browse files Browse the repository at this point in the history
> implement get_tracing_callbacks for system wide use cases

fix: proper config parsing
  • Loading branch information
maciejmajek committed Sep 30, 2024
1 parent 49baec9 commit 07659f2
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
10 changes: 10 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,13 @@ simple_model = "llama3.1"
complex_model = "llama3.1:70b"
embeddings_model = "llama3.1"
base_url = "http://localhost:11434"

[tracing]
project = "rai"

[tracing.langfuse]
use_langfuse = false
host = "https://cloud.langfuse.com"

[tracing.langsmith]
use_langsmith = false
55 changes: 54 additions & 1 deletion src/rai/rai/utils/model_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import Literal
from typing import List, Literal

import tomli
from langchain_core.callbacks.base import BaseCallbackHandler


@dataclass
Expand All @@ -40,12 +42,31 @@ class OllamaConfig(ModelConfig):
base_url: str


@dataclass
class LangfuseConfig:
use_langfuse: bool
host: str


@dataclass
class LangsmithConfig:
use_langsmith: bool


@dataclass
class TracingConfig:
project: str
langfuse: LangfuseConfig
langsmith: LangsmithConfig


@dataclass
class RAIConfig:
vendor: VendorConfig
aws: AWSConfig
openai: ModelConfig
ollama: OllamaConfig
tracing: TracingConfig


def load_config() -> RAIConfig:
Expand All @@ -56,6 +77,11 @@ def load_config() -> RAIConfig:
aws=AWSConfig(**config_dict["aws"]),
openai=ModelConfig(**config_dict["openai"]),
ollama=OllamaConfig(**config_dict["ollama"]),
tracing=TracingConfig(
project=config_dict["tracing"]["project"],
langfuse=LangfuseConfig(**config_dict["tracing"]["langfuse"]),
langsmith=LangsmithConfig(**config_dict["tracing"]["langsmith"]),
),
)


Expand Down Expand Up @@ -108,3 +134,30 @@ def get_embeddings_model():
)
else:
raise ValueError(f"Unknown embeddings vendor: {vendor}")


def get_tracing_callbacks() -> List[BaseCallbackHandler]:
config = load_config()
callbacks: List[BaseCallbackHandler] = []
if config.tracing.langfuse.use_langfuse:
from langfuse.callback import CallbackHandler # type: ignore

public_key = os.getenv("LANGFUSE_PUBLIC_KEY", None)
secret_key = os.getenv("LANGFUSE_SECRET_KEY", None)
if public_key is None or secret_key is None:
raise ValueError("LANGFUSE_PUBLIC_KEY or LANGFUSE_SECRET_KEY is not set")

callback = CallbackHandler(
public_key=public_key,
secret_key=secret_key,
host=config.tracing.langfuse.host,
)
callbacks.append(callback)

if config.tracing.langsmith.use_langsmith:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = config.tracing.project
api_key = os.getenv("LANGCHAIN_API_KEY", None)
if api_key is None:
raise ValueError("LANGCHAIN_API_KEY is not set")
return callbacks

0 comments on commit 07659f2

Please sign in to comment.