From 1d2822677ec1d47bded75338bcb49db5224db712 Mon Sep 17 00:00:00 2001 From: Nir Gazit Date: Wed, 2 Oct 2024 17:10:26 -0700 Subject: [PATCH] fix(langchain): token usage reporting (#2074) --- .../langchain/callback_handler.py | 51 +++++++++++++++-- .../poetry.lock | 14 ++--- .../tests/test_chains.py | 57 +++++++++++++------ .../tests/test_lcel.py | 14 +++-- .../tests/test_llms.py | 26 +++++++-- 5 files changed, 120 insertions(+), 42 deletions(-) diff --git a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py index 86ca081c2..819619b99 100644 --- a/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py +++ b/packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py @@ -153,9 +153,30 @@ def _set_chat_response(span: Span, response: LLMResult) -> None: if not should_send_prompts(): return + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + i = 0 for generations in response.generations: for generation in generations: + if ( + hasattr(generation, "message") + and hasattr(generation.message, "usage_metadata") + and generation.message.usage_metadata is not None + ): + input_tokens += ( + generation.message.usage_metadata.get("input_tokens") + or generation.message.usage_metadata.get("prompt_tokens") + or 0 + ) + output_tokens += ( + generation.message.usage_metadata.get("output_tokens") + or generation.message.usage_metadata.get("completion_tokens") + or 0 + ) + total_tokens = input_tokens + output_tokens + prefix = f"{SpanAttributes.LLM_COMPLETIONS}.{i}" if hasattr(generation, "text") and generation.text != "": span.set_attribute( @@ -201,6 +222,20 @@ def _set_chat_response(span: Span, response: LLMResult) -> None: ) i += 1 + if input_tokens > 0 or output_tokens > 0 or total_tokens > 0: + span.set_attribute( + SpanAttributes.LLM_USAGE_PROMPT_TOKENS, + input_tokens, + ) + span.set_attribute( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, + output_tokens, + ) + span.set_attribute( + SpanAttributes.LLM_USAGE_TOTAL_TOKENS, + total_tokens, + ) + class TraceloopCallbackHandler(BaseCallbackHandler): def __init__(self, tracer: Tracer) -> None: @@ -481,13 +516,19 @@ def on_llm_end( span = self._get_span(run_id) - token_usage = (response.llm_output or {}).get("token_usage") + token_usage = (response.llm_output or {}).get("token_usage") or ( + response.llm_output or {} + ).get("usage") if token_usage is not None: - prompt_tokens = token_usage.get("prompt_tokens") or token_usage.get( - "input_token_count" + prompt_tokens = ( + token_usage.get("prompt_tokens") + or token_usage.get("input_token_count") + or token_usage.get("input_tokens") ) - completion_tokens = token_usage.get("completion_tokens") or token_usage.get( - "generated_token_count" + completion_tokens = ( + token_usage.get("completion_tokens") + or token_usage.get("generated_token_count") + or token_usage.get("output_tokens") ) total_tokens = token_usage.get("total_tokens") or ( prompt_tokens + completion_tokens diff --git a/packages/opentelemetry-instrumentation-langchain/poetry.lock b/packages/opentelemetry-instrumentation-langchain/poetry.lock index 7fa033a46..6ab60eb12 100644 --- a/packages/opentelemetry-instrumentation-langchain/poetry.lock +++ b/packages/opentelemetry-instrumentation-langchain/poetry.lock @@ -274,8 +274,8 @@ files = [ jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, ] [package.extras] @@ -1046,8 +1046,8 @@ langchain-core = ">=0.3.0,<0.4.0" langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -1112,8 +1112,8 @@ langchain = ">=0.3.0,<0.4.0" langchain-core = ">=0.3.0,<0.4.0" langsmith = ">=0.1.112,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -1137,8 +1137,8 @@ jsonpatch = ">=1.33,<2.0" langsmith = ">=0.1.117,<0.2.0" packaging = ">=23.2,<25" pydantic = [ - {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, + {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, ] PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" @@ -1220,8 +1220,8 @@ files = [ httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, + {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, ] requests = ">=2,<3" @@ -1672,9 +1672,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1760,8 +1760,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.3" typing-extensions = [ - {version = ">=4.6.1", markers = "python_version < \"3.13\""}, {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, ] [package.extras] diff --git a/packages/opentelemetry-instrumentation-langchain/tests/test_chains.py b/packages/opentelemetry-instrumentation-langchain/tests/test_chains.py index ef0f58c6d..3909418a2 100644 --- a/packages/opentelemetry-instrumentation-langchain/tests/test_chains.py +++ b/packages/opentelemetry-instrumentation-langchain/tests/test_chains.py @@ -53,15 +53,32 @@ def test_sequential_chain(exporter): "SequentialChain.workflow", ] == [span.name for span in spans] - workflow_span = next(span for span in spans if span.name == "SequentialChain.workflow") - task_spans = [span for span in spans if span.name in ["synopsis.task", "LLMChain.task"]] + workflow_span = next( + span for span in spans if span.name == "SequentialChain.workflow" + ) + task_spans = [ + span for span in spans if span.name in ["synopsis.task", "LLMChain.task"] + ] llm_spans = [span for span in spans if span.name == "OpenAI.completion"] assert workflow_span.attributes[SpanAttributes.TRACELOOP_SPAN_KIND] == "workflow" - assert workflow_span.attributes[SpanAttributes.TRACELOOP_ENTITY_NAME] == "SequentialChain" - assert all(span.attributes[SpanAttributes.TRACELOOP_SPAN_KIND] == "task" for span in task_spans) - assert all(span.attributes[SpanAttributes.TRACELOOP_WORKFLOW_NAME] == "SequentialChain" for span in spans) - assert all(span.attributes[SpanAttributes.TRACELOOP_ENTITY_PATH] in ["synopsis", "LLMChain"] for span in llm_spans) + assert ( + workflow_span.attributes[SpanAttributes.TRACELOOP_ENTITY_NAME] + == "SequentialChain" + ) + assert all( + span.attributes[SpanAttributes.TRACELOOP_SPAN_KIND] == "task" + for span in task_spans + ) + assert all( + span.attributes[SpanAttributes.TRACELOOP_WORKFLOW_NAME] == "SequentialChain" + for span in spans + ) + assert all( + span.attributes[SpanAttributes.TRACELOOP_ENTITY_PATH] + in ["synopsis", "LLMChain"] + for span in llm_spans + ) synopsis_span = next(span for span in spans if span.name == "synopsis.task") review_span = next(span for span in spans if span.name == "LLMChain.task") @@ -200,12 +217,14 @@ def test_stream(exporter): chunks = list(runnable.stream({"product": "colorful socks"})) spans = exporter.get_finished_spans() - assert [ - "PromptTemplate.task", - "ChatCohere.chat", - "StrOutputParser.task", - "RunnableSequence.workflow", - ] == [span.name for span in spans] + assert set( + [ + "PromptTemplate.task", + "StrOutputParser.task", + "ChatCohere.chat", + "RunnableSequence.workflow", + ] + ) == set([span.name for span in spans]) assert len(chunks) == 62 @@ -223,10 +242,12 @@ async def test_astream(exporter): chunks.append(chunk) spans = exporter.get_finished_spans() - assert [ - "PromptTemplate.task", - "ChatCohere.chat", - "StrOutputParser.task", - "RunnableSequence.workflow", - ] == [span.name for span in spans] + assert set( + [ + "PromptTemplate.task", + "ChatCohere.chat", + "StrOutputParser.task", + "RunnableSequence.workflow", + ] + ) == set([span.name for span in spans]) assert len(chunks) == 144 diff --git a/packages/opentelemetry-instrumentation-langchain/tests/test_lcel.py b/packages/opentelemetry-instrumentation-langchain/tests/test_lcel.py index 498e84c22..d1b27a918 100644 --- a/packages/opentelemetry-instrumentation-langchain/tests/test_lcel.py +++ b/packages/opentelemetry-instrumentation-langchain/tests/test_lcel.py @@ -36,12 +36,14 @@ class Joke(BaseModel): spans = exporter.get_finished_spans() - assert [ - "ChatPromptTemplate.task", - "ChatOpenAI.chat", - "JsonOutputFunctionsParser.task", - "ThisIsATestChain.workflow", - ] == [span.name for span in spans] + assert set( + [ + "ChatPromptTemplate.task", + "JsonOutputFunctionsParser.task", + "ChatOpenAI.chat", + "ThisIsATestChain.workflow", + ] + ) == set([span.name for span in spans]) workflow_span = next( span for span in spans if span.name == "ThisIsATestChain.workflow" diff --git a/packages/opentelemetry-instrumentation-langchain/tests/test_llms.py b/packages/opentelemetry-instrumentation-langchain/tests/test_llms.py index 3b4e316b2..c78271644 100644 --- a/packages/opentelemetry-instrumentation-langchain/tests/test_llms.py +++ b/packages/opentelemetry-instrumentation-langchain/tests/test_llms.py @@ -90,6 +90,9 @@ def test_openai(exporter): assert ( openai_span.attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.role"] ) == "assistant" + assert openai_span.attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] == 24 + assert openai_span.attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] == 26 + assert openai_span.attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] == 50 @pytest.mark.vcr @@ -113,12 +116,14 @@ class Joke(BaseModel): spans = exporter.get_finished_spans() - assert [ - "ChatPromptTemplate.task", - "ChatOpenAI.chat", - "JsonOutputFunctionsParser.task", - "RunnableSequence.workflow", - ] == [span.name for span in spans] + assert set( + [ + "ChatPromptTemplate.task", + "JsonOutputFunctionsParser.task", + "ChatOpenAI.chat", + "RunnableSequence.workflow", + ] + ) == set([span.name for span in spans]) openai_span = next(span for span in spans if span.name == "ChatOpenAI.chat") @@ -169,6 +174,9 @@ class Joke(BaseModel): ) == response ) + assert openai_span.attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] == 76 + assert openai_span.attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] == 35 + assert openai_span.attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] == 111 @pytest.mark.vcr @@ -214,6 +222,9 @@ def test_anthropic(exporter): assert ( anthropic_span.attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.role"] ) == "assistant" + assert anthropic_span.attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] == 19 + assert anthropic_span.attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] == 22 + assert anthropic_span.attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] == 41 output = json.loads( workflow_span.attributes[SpanAttributes.TRACELOOP_ENTITY_OUTPUT] ) @@ -286,6 +297,9 @@ def test_bedrock(exporter): assert ( bedrock_span.attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.role"] ) == "assistant" + assert bedrock_span.attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] == 16 + assert bedrock_span.attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] == 27 + assert bedrock_span.attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] == 43 output = json.loads( workflow_span.attributes[SpanAttributes.TRACELOOP_ENTITY_OUTPUT] )