Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/agents/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@
from pydantic.dataclasses import dataclass


@dataclass
class RequestUsage:
"""Usage details for a single API request."""

input_tokens: int
"""Input tokens for this individual request."""

output_tokens: int
"""Output tokens for this individual request."""

total_tokens: int
"""Total tokens (input + output) for this individual request."""

input_tokens_details: InputTokensDetails
"""Details about the input tokens for this individual request."""

output_tokens_details: OutputTokensDetails
"""Details about the output tokens for this individual request."""


@dataclass
class Usage:
requests: int = 0
Expand All @@ -27,7 +47,27 @@ class Usage:
total_tokens: int = 0
"""Total tokens sent and received, across all requests."""

request_usage_entries: list[RequestUsage] = field(default_factory=list)
"""List of RequestUsage entries for accurate per-request cost calculation.

Each call to `add()` automatically creates an entry in this list if the added usage
represents a new request (i.e., has non-zero tokens).

Example:
For a run that makes 3 API calls with 100K, 150K, and 80K input tokens each,
the aggregated `input_tokens` would be 330K, but `request_usage_entries` would
preserve the [100K, 150K, 80K] breakdown, which could be helpful for detailed
cost calculation or context window management.
"""

def add(self, other: "Usage") -> None:
"""Add another Usage object to this one, aggregating all fields.

This method automatically preserves request_usage_entries.

Args:
other: The Usage object to add to this one.
"""
self.requests += other.requests if other.requests else 0
self.input_tokens += other.input_tokens if other.input_tokens else 0
self.output_tokens += other.output_tokens if other.output_tokens else 0
Expand All @@ -41,3 +81,18 @@ def add(self, other: "Usage") -> None:
reasoning_tokens=self.output_tokens_details.reasoning_tokens
+ other.output_tokens_details.reasoning_tokens
)

# Automatically preserve request_usage_entries.
# If the other Usage represents a single request with tokens, record it.
if other.requests == 1 and other.total_tokens > 0:
individual_usage = RequestUsage(
input_tokens=other.input_tokens,
output_tokens=other.output_tokens,
total_tokens=other.total_tokens,
input_tokens_details=other.input_tokens_details,
output_tokens_details=other.output_tokens_details,
)
self.request_usage_entries.append(individual_usage)
elif other.request_usage_entries:
# If the other Usage already has individual request breakdowns, merge them.
self.request_usage_entries.extend(other.request_usage_entries)
219 changes: 218 additions & 1 deletion tests/test_usage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents.usage import Usage
from agents.usage import RequestUsage, Usage


def test_usage_add_aggregates_all_fields():
Expand Down Expand Up @@ -50,3 +50,220 @@ def test_usage_add_aggregates_with_none_values():
assert u1.total_tokens == 15
assert u1.input_tokens_details.cached_tokens == 4
assert u1.output_tokens_details.reasoning_tokens == 6


def test_request_usage_creation():
"""Test that RequestUsage is created correctly."""
request_usage = RequestUsage(
input_tokens=100,
output_tokens=200,
total_tokens=300,
input_tokens_details=InputTokensDetails(cached_tokens=10),
output_tokens_details=OutputTokensDetails(reasoning_tokens=20),
)

assert request_usage.input_tokens == 100
assert request_usage.output_tokens == 200
assert request_usage.total_tokens == 300
assert request_usage.input_tokens_details.cached_tokens == 10
assert request_usage.output_tokens_details.reasoning_tokens == 20


def test_usage_add_preserves_single_request():
"""Test that adding a single request Usage creates an RequestUsage entry."""
u1 = Usage()
u2 = Usage(
requests=1,
input_tokens=100,
input_tokens_details=InputTokensDetails(cached_tokens=10),
output_tokens=200,
output_tokens_details=OutputTokensDetails(reasoning_tokens=20),
total_tokens=300,
)

u1.add(u2)

# Should preserve the request usage details
assert len(u1.request_usage_entries) == 1
request_usage = u1.request_usage_entries[0]
assert request_usage.input_tokens == 100
assert request_usage.output_tokens == 200
assert request_usage.total_tokens == 300
assert request_usage.input_tokens_details.cached_tokens == 10
assert request_usage.output_tokens_details.reasoning_tokens == 20


def test_usage_add_ignores_zero_token_requests():
"""Test that zero-token requests don't create request_usage_entries."""
u1 = Usage()
u2 = Usage(
requests=1,
input_tokens=0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens=0,
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
total_tokens=0,
)

u1.add(u2)

# Should not create a request_usage_entry for zero tokens
assert len(u1.request_usage_entries) == 0


def test_usage_add_ignores_multi_request_usage():
"""Test that multi-request Usage objects don't create request_usage_entries."""
u1 = Usage()
u2 = Usage(
requests=3, # Multiple requests
input_tokens=100,
input_tokens_details=InputTokensDetails(cached_tokens=10),
output_tokens=200,
output_tokens_details=OutputTokensDetails(reasoning_tokens=20),
total_tokens=300,
)

u1.add(u2)

# Should not create a request usage entry for multi-request usage
assert len(u1.request_usage_entries) == 0


def test_usage_add_merges_existing_request_usage_entries():
"""Test that existing request_usage_entries are merged when adding Usage objects."""
# Create first usage with request_usage_entries
u1 = Usage()
u2 = Usage(
requests=1,
input_tokens=100,
input_tokens_details=InputTokensDetails(cached_tokens=10),
output_tokens=200,
output_tokens_details=OutputTokensDetails(reasoning_tokens=20),
total_tokens=300,
)
u1.add(u2)

# Create second usage with request_usage_entries
u3 = Usage(
requests=1,
input_tokens=50,
input_tokens_details=InputTokensDetails(cached_tokens=5),
output_tokens=75,
output_tokens_details=OutputTokensDetails(reasoning_tokens=10),
total_tokens=125,
)

u1.add(u3)

# Should have both request_usage_entries
assert len(u1.request_usage_entries) == 2

# First request
first = u1.request_usage_entries[0]
assert first.input_tokens == 100
assert first.output_tokens == 200
assert first.total_tokens == 300

# Second request
second = u1.request_usage_entries[1]
assert second.input_tokens == 50
assert second.output_tokens == 75
assert second.total_tokens == 125


def test_usage_add_with_pre_existing_request_usage_entries():
"""Test adding Usage objects that already have request_usage_entries."""
u1 = Usage()

# Create a usage with request_usage_entries
u2 = Usage(
requests=1,
input_tokens=100,
input_tokens_details=InputTokensDetails(cached_tokens=10),
output_tokens=200,
output_tokens_details=OutputTokensDetails(reasoning_tokens=20),
total_tokens=300,
)
u1.add(u2)

# Create another usage with request_usage_entries
u3 = Usage(
requests=1,
input_tokens=50,
input_tokens_details=InputTokensDetails(cached_tokens=5),
output_tokens=75,
output_tokens_details=OutputTokensDetails(reasoning_tokens=10),
total_tokens=125,
)

# Add u3 to u1
u1.add(u3)

# Should have both request_usage_entries
assert len(u1.request_usage_entries) == 2
assert u1.request_usage_entries[0].input_tokens == 100
assert u1.request_usage_entries[1].input_tokens == 50


def test_usage_request_usage_entries_default_empty():
"""Test that request_usage_entries defaults to an empty list."""
u = Usage()
assert u.request_usage_entries == []


def test_anthropic_cost_calculation_scenario():
"""Test a realistic scenario for Sonnet 4.5 cost calculation with 200K token thresholds."""
# Simulate 3 API calls: 100K, 150K, and 80K input tokens each
# None exceed 200K, so they should all use the lower pricing tier

usage = Usage()

# First request: 100K input tokens
req1 = Usage(
requests=1,
input_tokens=100_000,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens=50_000,
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
total_tokens=150_000,
)
usage.add(req1)

# Second request: 150K input tokens
req2 = Usage(
requests=1,
input_tokens=150_000,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens=75_000,
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
total_tokens=225_000,
)
usage.add(req2)

# Third request: 80K input tokens
req3 = Usage(
requests=1,
input_tokens=80_000,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens=40_000,
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
total_tokens=120_000,
)
usage.add(req3)

# Verify aggregated totals
assert usage.requests == 3
assert usage.input_tokens == 330_000 # 100K + 150K + 80K
assert usage.output_tokens == 165_000 # 50K + 75K + 40K
assert usage.total_tokens == 495_000 # 150K + 225K + 120K

# Verify request_usage_entries preservation
assert len(usage.request_usage_entries) == 3
assert usage.request_usage_entries[0].input_tokens == 100_000
assert usage.request_usage_entries[1].input_tokens == 150_000
assert usage.request_usage_entries[2].input_tokens == 80_000

# All request_usage_entries are under 200K threshold
for req in usage.request_usage_entries:
assert req.input_tokens < 200_000
assert req.output_tokens < 200_000