Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmd/root/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error {
return fmt.Errorf("failed to create remote client: %w", err)
}

sessTemplate := session.New()
// Attach agent metadata so remote runtimes can attribute usage correctly.
sessTemplate := session.New(session.WithAgentMetadata(agentName, ""))
sessTemplate.ToolsApproved = autoApprove
sess, err = remoteClient.CreateSession(ctx, sessTemplate)
if err != nil {
Expand All @@ -235,7 +236,11 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error {
}

// Create session first to get its ID for OAuth state encoding
sess = session.New(session.WithMaxIterations(agent.MaxIterations()))
sess = session.New(
// Provide agent metadata so local sessions report attribution like remote ones.
session.WithMaxIterations(agent.MaxIterations()),
session.WithAgentMetadata(agentName, ""),
)
sess.ToolsApproved = autoApprove

// Create local runtime with root session ID for OAuth state encoding
Expand Down
26 changes: 21 additions & 5 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,27 @@ type TokenUsageEvent struct {
}

type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
ContextLength int `json:"context_length"`
ContextLimit int `json:"context_limit"`
Cost float64 `json:"cost"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
ContextLength int `json:"context_length"`
ContextLimit int `json:"context_limit"`
Cost float64 `json:"cost"`
Breakdown []*SessionUsage `json:"breakdown,omitempty"` // Per-session usage rows for hierarchical displays.
ActiveSessions []string `json:"active_sessions,omitempty"` // IDs of sessions currently streaming tokens.
}

// SessionUsage captures token and cost totals for a specific session (and its place in the hierarchy).
type SessionUsage struct {
SessionID string `json:"session_id"`
AgentName string `json:"agent_name"`
Title string `json:"title,omitempty"`
ParentSessionID string `json:"parent_session_id,omitempty"`
Depth int `json:"depth"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
Cost float64 `json:"cost"`
ContextLimit int `json:"context_limit,omitempty"`
Active bool `json:"active"`
}

func TokenUsage(inputTokens, outputTokens, contextLength, contextLimit int, cost float64) Event {
Expand Down
204 changes: 190 additions & 14 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ type LocalRuntime struct {
elicitationRequestCh chan ElicitationResult // Channel for receiving elicitation responses
elicitationEventsChannel chan Event // Current events channel for sending elicitation requests
elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel
usageTracker *usageTracker // Aggregates token usage across active sessions.
}

type streamResult struct {
Expand Down Expand Up @@ -158,6 +159,7 @@ func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
modelsStore: modelsStore,
sessionCompaction: true,
managedOAuth: true,
usageTracker: newUsageTracker(), // Start tracking usage immediately for the first session.
}

for _, opt := range opts {
Expand Down Expand Up @@ -199,6 +201,11 @@ func (r *LocalRuntime) registerDefaultTools() {
func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.Session, events chan Event) {
defer close(events)

if r.usageTracker != nil {
// Mark the session as inactive once streaming finishes.
r.usageTracker.markActive(sess.ID, false)
}

events <- StreamStopped(sess.ID, r.currentAgent)

telemetry.RecordSessionEnd(ctx)
Expand All @@ -208,6 +215,64 @@ func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.S
}
}

// emitUsageEvent sends a consolidated usage snapshot (including breakdown) to the UI layer.
func (r *LocalRuntime) emitUsageEvent(sess *session.Session, contextLimit int, events chan Event) {
if sess == nil || events == nil {
return
}

usage := &Usage{
ContextLength: sess.TotalInputTokens + sess.TotalOutputTokens,
ContextLimit: contextLimit,
InputTokens: sess.TotalInputTokens,
OutputTokens: sess.TotalOutputTokens,
Cost: sess.TotalCost,
}

if r.usageTracker != nil {
// Pull a fresh snapshot so totals reflect all sessions, including children.
summary := r.usageTracker.snapshot(contextLimit)
usage.InputTokens = summary.TotalInput
usage.OutputTokens = summary.TotalOutput
usage.Cost = summary.TotalCost
usage.ContextLength = summary.TotalInput + summary.TotalOutput
if len(summary.Rows) > 0 {
// Include per-session rows so the TUI can render the hierarchy.
usage.Breakdown = summary.Rows
}
if len(summary.ActiveSessions) > 0 {
// Surface the live session IDs for percent calculations.
usage.ActiveSessions = summary.ActiveSessions
}
// ContextLimit semantics:
// - If exactly one session is active, prefer that session's context limit.
// - If multiple (or zero) active sessions, set to 0 to avoid misleading percentages.
if len(summary.ActiveSessions) == 1 {
activeID := summary.ActiveSessions[0]
limit := 0
for _, row := range summary.Rows {
if row.SessionID == activeID && row.ContextLimit > 0 {
limit = row.ContextLimit
break
}
}
if limit > 0 {
usage.ContextLimit = limit
} else if summary.ContextLimit > 0 {
usage.ContextLimit = summary.ContextLimit
}
} else {
// Ambiguous across multiple sessions; signal UI to suppress percent
usage.ContextLimit = 0
}
}

events <- &TokenUsageEvent{
Type: "token_usage",
Usage: usage,
}
}

// RunStream starts the agent's interaction loop and returns a channel of events
func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event {
slog.Debug("Starting runtime stream", "agent", r.currentAgent, "session_id", sess.ID)
Expand All @@ -231,6 +296,15 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

r.emitAgentWarnings(a, events)

if sess.AgentName == "" {
// Ensure we capture which agent owns the session for attribution.
sess.AgentName = a.Name()
}
if r.usageTracker != nil {
// Register and activate the session so usage snapshots include it.
r.usageTracker.registerSession(sess.ID, sess.AgentName, sess.ParentSessionID, sess.Title, 0)
r.usageTracker.markActive(sess.ID, true)
}
for _, toolset := range a.ToolSets() {
toolset.SetElicitationHandler(r.elicitationHandler)
toolset.SetOAuthSuccessHandler(func() {
Expand Down Expand Up @@ -311,6 +385,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
m, err := r.modelsStore.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Failed to get model definition", "error", err)
} else if r.usageTracker != nil && m != nil {
// Capture provider context limits so percentages are meaningful.
r.usageTracker.registerSession(sess.ID, "", "", "", m.Limit.Context)
}

slog.Debug("Creating chat completion stream", "agent", a.Name())
Expand Down Expand Up @@ -374,7 +451,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
if m != nil {
contextLimit = m.Limit.Context
}
events <- TokenUsage(sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
r.emitUsageEvent(sess, contextLimit, events)

if m != nil && r.sessionCompaction {
if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
Expand All @@ -383,7 +460,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
if len(res.Calls) == 0 {
events <- SessionCompaction(sess.ID, "start", r.currentAgent)
r.Summarize(ctx, sess, events)
events <- TokenUsage(sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
r.emitUsageEvent(sess, contextLimit, events)
events <- SessionCompaction(sess.ID, "completed", r.currentAgent)
}
}
Expand All @@ -397,7 +474,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
if sess.InputTokens+sess.OutputTokens > int(float64(contextLimit)*0.9) {
events <- SessionCompaction(sess.ID, "start", r.currentAgent)
r.Summarize(ctx, sess, events)
events <- TokenUsage(sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
r.emitUsageEvent(sess, contextLimit, events)
events <- SessionCompaction(sess.ID, "completed", r.currentAgent)
}
}
Expand Down Expand Up @@ -520,13 +597,26 @@ func (r *LocalRuntime) Run(ctx context.Context, sess *session.Session) ([]sessio
func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStream, a *agent.Agent, agentTools []tools.Tool, sess *session.Session, m *modelsdev.Model, events chan Event) (streamResult, error) {
defer stream.Close()

// Start each provider call with a clean slate so deltas stay accurate.
sess.ResetUsageTracking()

var fullContent strings.Builder
var fullReasoningContent strings.Builder
var thinkingSignature string
var toolCalls []tools.ToolCall
// Track which tool call indices we've already emitted partial events for
emittedPartialEvents := make(map[string]bool)

var lastPromptTokens int
var lastCompletionTokens int
var lastCachedInputTokens int
var lastCachedOutputTokens int

// Accumulate telemetry totals for this model call; emit once on completion
var telemetryIn int
var telemetryOut int
var telemetryCost float64

for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
Expand All @@ -537,28 +627,82 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
}

if response.Usage != nil {
if m != nil {
sess.Cost += (float64(response.Usage.InputTokens)*m.Cost.Input +
float64(response.Usage.OutputTokens+response.Usage.ReasoningTokens)*m.Cost.Output +
float64(response.Usage.CachedInputTokens)*m.Cost.CacheRead +
float64(response.Usage.CachedOutputTokens)*m.Cost.CacheWrite) / 1e6
// Convert absolute token counters into per-update deltas.
promptTokensAbs := response.Usage.InputTokens
completionTokensAbs := response.Usage.OutputTokens + response.Usage.ReasoningTokens
cachedInputTokensAbs := response.Usage.CachedInputTokens
cachedOutputTokensAbs := response.Usage.CachedOutputTokens

inputTokensAbs := promptTokensAbs + cachedInputTokensAbs
outputTokensAbs := completionTokensAbs + cachedOutputTokensAbs

inputDelta := inputTokensAbs - sess.InputTokens
if inputDelta < 0 {
inputDelta = 0
}
outputDelta := outputTokensAbs - sess.OutputTokens
if outputDelta < 0 {
outputDelta = 0
}

sess.InputTokens = response.Usage.InputTokens + response.Usage.CachedInputTokens
sess.OutputTokens = response.Usage.OutputTokens + response.Usage.CachedOutputTokens + response.Usage.ReasoningTokens
promptDelta := promptTokensAbs - lastPromptTokens
if promptDelta < 0 {
promptDelta = 0
}
completionDelta := completionTokensAbs - lastCompletionTokens
if completionDelta < 0 {
completionDelta = 0
}
cachedInputDelta := cachedInputTokensAbs - lastCachedInputTokens
if cachedInputDelta < 0 {
cachedInputDelta = 0
}
cachedOutputDelta := cachedOutputTokensAbs - lastCachedOutputTokens
if cachedOutputDelta < 0 {
cachedOutputDelta = 0
}

lastPromptTokens = promptTokensAbs
lastCompletionTokens = completionTokensAbs
lastCachedInputTokens = cachedInputTokensAbs
lastCachedOutputTokens = cachedOutputTokensAbs

modelName := "unknown"
var costDelta float64
if m != nil {
modelName = m.Name
costDelta = (float64(promptDelta)*m.Cost.Input +
float64(completionDelta)*m.Cost.Output +
float64(cachedInputDelta)*m.Cost.CacheRead +
float64(cachedOutputDelta)*m.Cost.CacheWrite) / 1e6
}

if inputDelta > 0 || outputDelta > 0 || costDelta > 0 {
// Persist the delta so cumulative totals and UI breakdowns stay in sync.
sess.AddUsageDelta(inputDelta, outputDelta, costDelta)
if r.usageTracker != nil {
// Mirror the delta in the tracker for cross-session summaries.
r.usageTracker.addDelta(sess.ID, inputDelta, outputDelta, costDelta)
}
// Accumulate totals for a single end-of-call telemetry event
telemetryIn += inputDelta
telemetryOut += outputDelta
telemetryCost += costDelta
}
telemetry.RecordTokenUsage(ctx, modelName, int64(response.Usage.InputTokens), int64(response.Usage.OutputTokens+response.Usage.ReasoningTokens), sess.Cost)
}

if len(response.Choices) == 0 {
continue
}
choice := response.Choices[0]
if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength {
// Emit a single telemetry record for this model call, if any usage was recorded
if telemetryIn > 0 || telemetryOut > 0 || telemetryCost > 0 {
// Emit one OTEL record per completion to avoid flooding metrics.
modelName := "unknown"
if m != nil {
modelName = m.Name
}
telemetry.RecordTokenUsage(ctx, modelName, int64(telemetryIn), int64(telemetryOut), telemetryCost)
}
return streamResult{
Calls: toolCalls,
Content: fullContent.String(),
Expand Down Expand Up @@ -648,6 +792,16 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
// If the stream completed without producing any content or tool calls, likely because of a token limit, stop to avoid breaking the request loop
// NOTE(krissetto): this can likely be removed once compaction works properly with all providers (aka dmr)
stoppedDueToNoOutput := fullContent.Len() == 0 && len(toolCalls) == 0

// Stream completed without an explicit finish reason; emit telemetry once if usage was recorded
if telemetryIn > 0 || telemetryOut > 0 || telemetryCost > 0 {
// When providers end without a finish reason, still flush the accumulated metrics.
modelName := "unknown"
if m != nil {
modelName = m.Name
}
telemetry.RecordTokenUsage(ctx, modelName, int64(telemetryIn), int64(telemetryOut), telemetryCost)
}
return streamResult{
Calls: toolCalls,
Content: fullContent.String(),
Expand Down Expand Up @@ -963,10 +1117,15 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses
session.WithSystemMessage(memberAgentTask),
session.WithImplicitUserMessage("", "Follow the default instructions"),
session.WithMaxIterations(child.MaxIterations()),
session.WithAgentMetadata(params.Agent, sess.ID),
)
s.SendUserMessage = false
s.Title = "Transferred task"
s.ToolsApproved = sess.ToolsApproved
if r.usageTracker != nil {
// Track the delegated session so its usage appears under the parent.
r.usageTracker.registerSession(s.ID, s.AgentName, s.ParentSessionID, s.Title, 0)
}

for event := range r.RunStream(ctx, s) {
evts <- event
Expand All @@ -978,7 +1137,20 @@ func (r *LocalRuntime) handleTaskTransfer(ctx context.Context, sess *session.Ses
}

sess.ToolsApproved = s.ToolsApproved
sess.Cost += s.Cost
// Avoid double-counting: roll the child totals into the parent instead of adding raw costs.
sess.MergeChildUsage(s)

contextLimit := 0
if parentAgent, _ := r.team.Agent(ca); parentAgent != nil {
if model := parentAgent.Model(); model != nil {
if modelDef, err := r.modelsStore.GetModel(ctx, model.ID()); err == nil && modelDef != nil {
contextLimit = modelDef.Limit.Context
}
}
}

// Publishing a usage snapshot here keeps the UI in sync after transfers.
r.emitUsageEvent(sess, contextLimit, evts)

sess.AddSubSession(s)

Expand Down Expand Up @@ -1041,6 +1213,10 @@ func (r *LocalRuntime) generateSessionTitle(ctx context.Context, sess *session.S
return
}
sess.Title = title
if r.usageTracker != nil {
// Refresh tracker metadata so new titles show up in the sidebar.
r.usageTracker.registerSession(sess.ID, "", "", sess.Title, 0)
}
slog.Debug("Generated session title", "session_id", sess.ID, "title", title)
events <- SessionTitle(sess.ID, title, r.currentAgent)
}
Expand Down
Loading