community: Add OpenAI prompt caching and reasoning tokens tracking #27135

Merged
91 changes: 83 additions & 8 deletions libs/community/langchain_community/callbacks/openai_info.py
@@ -1,35 +1,45 @@
"""Callback Handler that prints to std out."""

import threading
from enum import Enum, auto
from typing import Any, Dict, List

from langchain_core._api import warn_deprecated
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

MODEL_COST_PER_1K_TOKENS = {
# OpenAI o1-preview input
"o1-preview": 0.015,
"o1-preview-cached": 0.0075,
"o1-preview-2024-09-12": 0.015,
"o1-preview-2024-09-12-cached": 0.0075,
# OpenAI o1-preview output
"o1-preview-completion": 0.06,
"o1-preview-2024-09-12-completion": 0.06,
# OpenAI o1-mini input
"o1-mini": 0.003,
"o1-mini-cached": 0.0015,
"o1-mini-2024-09-12": 0.003,
"o1-mini-2024-09-12-cached": 0.0015,
# OpenAI o1-mini output
"o1-mini-completion": 0.012,
"o1-mini-2024-09-12-completion": 0.012,
# GPT-4o-mini input
"gpt-4o-mini": 0.00015,
"gpt-4o-mini-cached": 0.000075,
"gpt-4o-mini-2024-07-18": 0.00015,
"gpt-4o-mini-2024-07-18-cached": 0.000075,
# GPT-4o-mini output
"gpt-4o-mini-completion": 0.0006,
"gpt-4o-mini-2024-07-18-completion": 0.0006,
# GPT-4o input
"gpt-4o": 0.0025,
"gpt-4o-cached": 0.00125,
"gpt-4o-2024-05-13": 0.005,
"gpt-4o-2024-08-06": 0.0025,
"gpt-4o-2024-08-06-cached": 0.00125,
"gpt-4o-2024-11-20": 0.0025,
# GPT-4o output
"gpt-4o-completion": 0.01,
@@ -140,43 +150,73 @@
}
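
Note for reviewers: the table maps a standardized model key to USD per 1K tokens, with cached-prompt and completion prices under the new "-cached" and existing "-completion" suffixes. An illustrative lookup, using values from the table above:

# Illustrative lookups against the table above (USD per 1K tokens):
MODEL_COST_PER_1K_TOKENS["gpt-4o"]             # 0.0025  (uncached prompt)
MODEL_COST_PER_1K_TOKENS["gpt-4o-cached"]      # 0.00125 (cached prompt, half the uncached rate)
MODEL_COST_PER_1K_TOKENS["gpt-4o-completion"]  # 0.01    (completion)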


class TokenType(Enum):
"""Token type enum."""

PROMPT = auto()
PROMPT_CACHED = auto()
COMPLETION = auto()


def standardize_model_name(
model_name: str,
is_completion: bool = False,
*,
token_type: TokenType = TokenType.PROMPT,
) -> str:
"""
Standardize the model name to a format that can be used in the OpenAI API.

Args:
model_name: Model name to standardize.
is_completion: Whether the model is used for completion or not.
Defaults to False. Deprecated in favor of ``token_type``.
token_type: Token type. Defaults to ``TokenType.PROMPT``.

Returns:
Standardized model name.

"""
if is_completion:
warn_deprecated(
since="0.3.13",
message=(
"is_completion is deprecated. Use token_type instead. Example:\n\n"
"from langchain_community.callbacks.openai_info import TokenType\n\n"
"standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n"
),
removal="1.0",
)
token_type = TokenType.COMPLETION
model_name = model_name.lower()
if ".ft-" in model_name:
model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
if ":ft-" in model_name:
model_name = model_name.split(":")[0] + "-finetuned-legacy"
if "ft:" in model_name:
model_name = model_name.split(":")[1] + "-finetuned"
if token_type == TokenType.COMPLETION and (
model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
or model_name.startswith("o1-")
or ("finetuned" in model_name and "legacy" not in model_name)
):
return model_name + "-completion"
if token_type == TokenType.PROMPT_CACHED and (
model_name.startswith("gpt-4o") or model_name.startswith("o1")
):
return model_name + "-cached"
else:
return model_name
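
A small sketch of the mapping this function now performs (outputs derived from the branches above):

standardize_model_name("gpt-4o")                                       # -> "gpt-4o"
standardize_model_name("gpt-4o", token_type=TokenType.COMPLETION)      # -> "gpt-4o-completion"
standardize_model_name("gpt-4o", token_type=TokenType.PROMPT_CACHED)   # -> "gpt-4o-cached"
standardize_model_name("o1-mini", token_type=TokenType.PROMPT_CACHED)  # -> "o1-mini-cached"
standardize_model_name("gpt-4o", is_completion=True)                   # -> "gpt-4o-completion", with a deprecation warning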


def get_openai_token_cost_for_model(
model_name: str,
num_tokens: int,
is_completion: bool = False,
*,
token_type: TokenType = TokenType.PROMPT,
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
@@ -185,12 +225,24 @@ def get_openai_token_cost_for_model(
model_name: Name of the model
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Defaults to False. Deprecated in favor of ``token_type``.
token_type: Token type. Defaults to ``TokenType.PROMPT``.

Returns:
Cost in USD.
"""
if is_completion:
warn_deprecated(
since="0.3.13",
message=(
"is_completion is deprecated. Use token_type instead. Example:\n\n"
"from langchain_community.callbacks.openai_info import TokenType\n\n"
"get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n" # noqa: E501
),
removal="1.0",
)
token_type = TokenType.COMPLETION
model_name = standardize_model_name(model_name, token_type=token_type)
if model_name not in MODEL_COST_PER_1K_TOKENS:
raise ValueError(
f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
@@ -204,7 +256,9 @@ class OpenAICallbackHandler(BaseCallbackHandler):

total_tokens: int = 0
prompt_tokens: int = 0
prompt_tokens_cached: int = 0
completion_tokens: int = 0
reasoning_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0

@@ -216,7 +270,9 @@ def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\t\tPrompt Tokens Cached: {self.prompt_tokens_cached}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"\t\tReasoning Tokens: {self.reasoning_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)
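
With the two new counters, the printed summary gains two indented lines (illustrative values):

Tokens Used: 1200
    Prompt Tokens: 1000
        Prompt Tokens Cached: 400
    Completion Tokens: 200
        Reasoning Tokens: 50
Successful Requests: 1
Total Cost (USD): $0.004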
@@ -258,6 +314,10 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
else:
usage_metadata = None
response_metadata = None

prompt_tokens_cached = 0
reasoning_tokens = 0

if usage_metadata:
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
completion_tokens = usage_metadata["output_tokens"]
@@ -270,7 +330,12 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)

if "cache_read" in usage_metadata.get("input_token_details", {}):
prompt_tokens_cached = usage_metadata["input_token_details"][
"cache_read"
]
if "reasoning" in usage_metadata.get("output_token_details", {}):
reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
else:
if response.llm_output is None:
return None
@@ -287,11 +352,19 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)

if model_name in MODEL_COST_PER_1K_TOKENS:
uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached
uncached_prompt_cost = get_openai_token_cost_for_model(
model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT
)
cached_prompt_cost = get_openai_token_cost_for_model(
model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
)
prompt_cost = uncached_prompt_cost + cached_prompt_cost
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, token_type=TokenType.COMPLETION
)
else:
completion_cost = 0
prompt_cost = 0
@@ -301,7 +374,9 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.prompt_tokens_cached += prompt_tokens_cached
self.completion_tokens += completion_tokens
self.reasoning_tokens += reasoning_tokens
self.successful_requests += 1
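
A worked example of the split above, using the gpt-4o rates from the pricing table — 1,000 prompt tokens of which 400 were cache reads, plus 200 completion tokens:

# uncached prompt: (1000 - 400) / 1000 * 0.0025  = 0.0015
# cached prompt:          400   / 1000 * 0.00125 = 0.0005
# completion:             200   / 1000 * 0.01    = 0.002
# total_cost += 0.0015 + 0.0005 + 0.002          = 0.004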

def __copy__(self) -> "OpenAICallbackHandler":
40 changes: 39 additions & 1 deletion libs/community/tests/unit_tests/callbacks/test_openai_info.py
@@ -3,7 +3,8 @@

import numpy as np
import pytest
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.utils.pydantic import get_fields

from langchain_community.callbacks import OpenAICallbackHandler
@@ -35,6 +36,43 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None:
assert handler.total_cost > 0


def test_on_llm_end_with_chat_generation(handler: OpenAICallbackHandler) -> None:
response = LLMResult(
generations=[
[
ChatGeneration(
text="Hello, world!",
message=AIMessage(
content="Hello, world!",
usage_metadata={
"input_tokens": 2,
"output_tokens": 2,
"total_tokens": 4,
"input_token_details": {
"cache_read": 1,
},
"output_token_details": {
"reasoning": 1,
},
},
),
)
]
],
llm_output={
"model_name": get_fields(BaseOpenAI)["model_name"].default,
},
)
handler.on_llm_end(response)
assert handler.successful_requests == 1
assert handler.total_tokens == 4
assert handler.prompt_tokens == 2
assert handler.prompt_tokens_cached == 1
assert handler.completion_tokens == 2
assert handler.reasoning_tokens == 1
assert handler.total_cost > 0
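
How the new counters surface to users — a minimal sketch, assuming a configured OpenAI chat model bound to the name `llm` (not part of this diff):

from langchain_community.callbacks import get_openai_callback

with get_openai_callback() as cb:
    llm.invoke("hello")  # `llm` is a hypothetical ChatOpenAI instance
print(cb.prompt_tokens_cached)  # cache-read prompt tokens, priced at the -cached rate
print(cb.reasoning_tokens)      # reasoning tokens, counted within completion tokens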


def test_on_llm_end_custom_model(handler: OpenAICallbackHandler) -> None:
response = LLMResult(
generations=[],