Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: propagate prompt management attributes to llm spans #109

Merged
merged 6 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/traceloop-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ addopts = "--cov --cov-report html:'../../coverage/packages/traceloop-sdk/html'

[tool.poetry]
name = "traceloop-sdk"
version = "0.0.62"
version = "0.0.63"
description = "Traceloop Software Development Kit (SDK) for Python"
authors = [
"Gal Kleinman <gal@traceloop.com>",
Expand Down
6 changes: 5 additions & 1 deletion packages/traceloop-sdk/traceloop/sdk/prompts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from jinja2 import Environment, meta
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential, retry_if_exception

from traceloop.sdk import base_url
from traceloop.sdk.prompts.model import Prompt, PromptVersion, TemplateEngine
from traceloop.sdk.prompts.registry import PromptRegistry
from traceloop.sdk.tracing.tracing import set_prompt_tracing_context

MAX_RETRIES = os.getenv("TRACELOOP_PROMPT_MANAGER_MAX_RETRIES") or 3
POLLING_INTERVAL = os.getenv("TRACELOOP_PROMPT_MANAGER_POLLING_INTERVAL") or 5
PROMPTS_ENDPOINT = "https://app.traceloop.com/api/prompts"
PROMPTS_ENDPOINT = f"{base_url()}/api/prompts"


def get_effective_version(prompt: Prompt) -> PromptVersion:
Expand Down Expand Up @@ -62,6 +64,8 @@ def render_prompt(self, key: str, **args):
params_dict.update(prompt_version.llm_config)
params_dict.pop("mode")

set_prompt_tracing_context(prompt.key, prompt_version.version, prompt_version.name)

return params_dict

def render_messages(self, prompt_version: PromptVersion, **args):
Expand Down
27 changes: 25 additions & 2 deletions packages/traceloop-sdk/traceloop/sdk/tracing/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

class TracerWrapper(object):
def __new__(
cls, disable_batch=False, exporter: SpanExporter = None
cls, disable_batch=False, exporter: SpanExporter = None
) -> "TracerWrapper":
if not hasattr(cls, "instance"):
obj = cls.instance = super(TracerWrapper, cls).__new__(cls)
Expand Down Expand Up @@ -69,7 +69,7 @@ def exit_handler(self):

@staticmethod
def set_static_params(
app_name: str, endpoint: str, headers: dict[str, str]
app_name: str, endpoint: str, headers: dict[str, str]
) -> None:
TracerWrapper.app_name = app_name
TracerWrapper.endpoint = endpoint
Expand Down Expand Up @@ -109,6 +109,12 @@ def set_workflow_name(workflow_name: str) -> None:
attach(set_value("workflow_name", workflow_name))


def set_prompt_tracing_context(key: str, version: int, version_name: str) -> None:
attach(set_value("prompt_key", key))
attach(set_value("prompt_version", version))
attach(set_value("prompt_version_name", version_name))


def span_processor_on_start(span, parent_context):
workflow_name = get_value("workflow_name")
if workflow_name is not None:
Expand All @@ -123,6 +129,23 @@ def span_processor_on_start(span, parent_context):
for key, value in association_properties.items():
span.set_attribute(f"{SpanAttributes.TRACELOOP_ASSOCIATION_PROPERTIES}.{key}", value)

if is_llm_span(span):
prompt_key = get_value("prompt_key")
if prompt_key is not None:
span.set_attribute("traceloop.prompt.key", prompt_key)

prompt_version = get_value("prompt_version")
if prompt_version is not None:
span.set_attribute("traceloop.prompt.version", prompt_version)

prompt_version_name = get_value("prompt_version_name")
if prompt_version_name is not None:
span.set_attribute("traceloop.prompt.version_name", prompt_version_name)


def is_llm_span(span) -> bool:
return span.attributes.get(SpanAttributes.LLM_REQUEST_TYPE) is not None


def init_spans_exporter(api_endpoint: str, headers: dict[str, str]) -> SpanExporter:
if "http" in api_endpoint.lower() or "https" in api_endpoint.lower():
Expand Down