Skip to content

Commit

Permalink
[MLOB-2098] feat(llmobs): record bedrock token counts (#5152)
Browse files Browse the repository at this point in the history
* working version by patching deserializer

* wip

* cleanup

* use tokens on response directly if available

* make it run on all command types if available

* make token extraction cleaner

* test output

* parseint headers

* remove comment

* rename channel

* cleanup

* simpler instance patching

* fmt

* Update packages/datadog-instrumentations/src/helpers/hooks.js

* check subscribers
  • Loading branch information
sabrenner authored Jan 31, 2025
1 parent 34c3763 commit f01d385
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 35 deletions.
16 changes: 16 additions & 0 deletions packages/datadog-instrumentations/src/aws-sdk.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ function wrapRequest (send) {
}
}

function wrapDeserialize (deserialize, channelSuffix) {
  // Publish the raw response headers before the original deserializer runs,
  // so subscribers can read values (e.g. bedrock token-count headers) that
  // are otherwise dropped during deserialization.
  const headersChannel = channel(`apm:aws:response:deserialize:${channelSuffix}`)

  return function wrappedDeserialize (response, ...rest) {
    if (headersChannel.hasSubscribers) {
      headersChannel.publish({ headers: response.headers })
    }

    return deserialize.call(this, response, ...rest)
  }
}

function wrapSmithySend (send) {
return function (command, ...args) {
const cb = args[args.length - 1]
Expand All @@ -61,6 +73,10 @@ function wrapSmithySend (send) {
const responseStartChannel = channel(`apm:aws:response:start:${channelSuffix}`)
const responseFinishChannel = channel(`apm:aws:response:finish:${channelSuffix}`)

if (typeof command.deserialize === 'function') {
shimmer.wrap(command, 'deserialize', deserialize => wrapDeserialize(deserialize, channelSuffix))
}

return innerAr.runInAsyncScope(() => {
startCh.publish({
serviceIdentifier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,23 @@ const PROVIDER = {
}

class Generation {
  /**
   * Normalized representation of a single model generation extracted from a
   * bedrock response body.
   *
   * @param {Object} [options]
   * @param {string|Object} [options.message=''] - generated content; non-string
   *   values (e.g. embedding lists) are JSON-stringified
   * @param {string} [options.finishReason=''] - provider-reported stop reason
   * @param {string} [options.choiceId=''] - provider choice id; stored as
   *   `undefined` when absent so it is omitted from serialized output
   * @param {string} [options.role] - role of the generated message (e.g. 'assistant')
   * @param {number} [options.inputTokens] - prompt token count, when the body reports it
   * @param {number} [options.outputTokens] - completion token count, when the body reports it
   */
  constructor ({
    message = '',
    finishReason = '',
    choiceId = '',
    role,
    inputTokens,
    outputTokens
  } = {}) {
    // stringify message as it could be a single generated message as well as a list of embeddings
    this.message = typeof message === 'string' ? message : JSON.stringify(message) || ''
    this.finishReason = finishReason || ''
    this.choiceId = choiceId || undefined
    this.role = role
    this.usage = {
      inputTokens,
      outputTokens
    }
  }
}

Expand Down Expand Up @@ -202,9 +214,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
if (generations.length > 0) {
const generation = generations[0]
return new Generation({
message: generation.message,
message: generation.message.content,
finishReason: generation.finish_reason,
choiceId: shouldSetChoiceIds ? generation.id : undefined
choiceId: shouldSetChoiceIds ? generation.id : undefined,
role: generation.message.role,
inputTokens: body.usage?.prompt_tokens,
outputTokens: body.usage?.completion_tokens
})
}
}
Expand All @@ -214,7 +229,9 @@ function extractTextAndResponseReason (response, provider, modelName) {
return new Generation({
message: completion.data?.text,
finishReason: completion?.finishReason,
choiceId: shouldSetChoiceIds ? completion?.id : undefined
choiceId: shouldSetChoiceIds ? completion?.id : undefined,
inputTokens: body.usage?.prompt_tokens,
outputTokens: body.usage?.completion_tokens
})
}
return new Generation()
Expand All @@ -226,7 +243,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
const results = body.results || []
if (results.length > 0) {
const result = results[0]
return new Generation({ message: result.outputText, finishReason: result.completionReason })
return new Generation({
message: result.outputText,
finishReason: result.completionReason,
inputTokens: body.inputTextTokenCount,
outputTokens: result.tokenCount
})
}
break
}
Expand All @@ -252,7 +274,12 @@ function extractTextAndResponseReason (response, provider, modelName) {
break
}
case PROVIDER.META: {
return new Generation({ message: body.generation, finishReason: body.stop_reason })
return new Generation({
message: body.generation,
finishReason: body.stop_reason,
inputTokens: body.prompt_token_count,
outputTokens: body.generation_token_count
})
}
case PROVIDER.MISTRAL: {
const mistralGenerations = body.outputs || []
Expand Down
56 changes: 37 additions & 19 deletions packages/datadog-plugin-aws-sdk/test/fixtures/bedrockruntime.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,22 @@ bedrockruntime.models = [
},
response: {
inputTextTokenCount: 7,
results: {
inputTextTokenCount: 7,
results: [
{
tokenCount: 35,
outputText: '\n' +
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
'Paris is one of the most populous cities in the European Union. ',
completionReason: 'FINISH'
}
]
}
}
results: [{
tokenCount: 35,
outputText: '\n' +
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
'Paris is one of the most populous cities in the European Union. ',
completionReason: 'FINISH'
}]
},
usage: {
inputTokens: 7,
outputTokens: 35,
totalTokens: 42
},
output: '\n' +
'Paris is the capital of France. France is a country that is located in Western Europe. ' +
'Paris is one of the most populous cities in the European Union. '
},
{
provider: PROVIDER.AI21,
Expand Down Expand Up @@ -79,7 +82,14 @@ bedrockruntime.models = [
completion_tokens: 7,
total_tokens: 17
}
}
},
usage: {
inputTokens: 10,
outputTokens: 7,
totalTokens: 17
},
output: 'The capital of France is Paris.',
outputRole: 'assistant'
},
{
provider: PROVIDER.ANTHROPIC,
Expand All @@ -97,7 +107,8 @@ bedrockruntime.models = [
completion: ' Paris is the capital of France.',
stop_reason: 'stop_sequence',
stop: '\n\nHuman:'
}
},
output: ' Paris is the capital of France.'
},
{
provider: PROVIDER.COHERE,
Expand All @@ -120,8 +131,8 @@ bedrockruntime.models = [
}
],
prompt: 'What is the capital of France?'
}

},
output: ' The capital of France is Paris. \n'
},
{
provider: PROVIDER.META,
Expand All @@ -138,7 +149,13 @@ bedrockruntime.models = [
prompt_token_count: 10,
generation_token_count: 7,
stop_reason: 'stop'
}
},
usage: {
inputTokens: 10,
outputTokens: 7,
totalTokens: 17
},
output: '\n\nThe capital of France is Paris.'
},
{
provider: PROVIDER.MISTRAL,
Expand All @@ -158,7 +175,8 @@ bedrockruntime.models = [
stop_reason: 'stop'
}
]
}
},
output: 'The capital of France is Paris.'
}
]
bedrockruntime.modelConfig = {
Expand Down
51 changes: 48 additions & 3 deletions packages/dd-trace/src/llmobs/plugins/bedrockruntime.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ const {
parseModelId
} = require('../../../../datadog-plugin-aws-sdk/src/services/bedrockruntime/utils')

const enabledOperations = ['invokeModel']
const ENABLED_OPERATIONS = ['invokeModel']

const requestIdsToTokens = {}

class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
constructor () {
Expand All @@ -18,7 +20,7 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
const request = response.request
const operation = request.operation
// avoids instrumenting other non supported runtime operations
if (!enabledOperations.includes(operation)) {
if (!ENABLED_OPERATIONS.includes(operation)) {
return
}
const { modelProvider, modelName } = parseModelId(request.params.modelId)
Expand All @@ -30,6 +32,17 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
const span = storage.getStore()?.span
this.setLLMObsTags({ request, span, response, modelProvider, modelName })
})

this.addSub('apm:aws:response:deserialize:bedrockruntime', ({ headers }) => {
const requestId = headers['x-amzn-requestid']
const inputTokenCount = headers['x-amzn-bedrock-input-token-count']
const outputTokenCount = headers['x-amzn-bedrock-output-token-count']

requestIdsToTokens[requestId] = {
inputTokensFromHeaders: inputTokenCount && parseInt(inputTokenCount),
outputTokensFromHeaders: outputTokenCount && parseInt(outputTokenCount)
}
})
}

setLLMObsTags ({ request, span, response, modelProvider, modelName }) {
Expand All @@ -52,7 +65,39 @@ class BedrockRuntimeLLMObsPlugin extends BaseLLMObsPlugin {
})

// add I/O tags
this._tagger.tagLLMIO(span, requestParams.prompt, textAndResponseReason.message)
this._tagger.tagLLMIO(
span,
requestParams.prompt,
[{ content: textAndResponseReason.message, role: textAndResponseReason.role }]
)

// add token metrics
const { inputTokens, outputTokens, totalTokens } = extractTokens({
requestId: response.$metadata.requestId,
usage: textAndResponseReason.usage
})
this._tagger.tagMetrics(span, {
inputTokens,
outputTokens,
totalTokens
})
}
}

/**
 * Resolves the token counts for a completed bedrock request.
 *
 * Prefers counts reported in the response body (`usage`), falling back to the
 * counts captured from the response headers by the deserialize hook. The
 * cached header entry is always removed so `requestIdsToTokens` cannot grow
 * unboundedly.
 *
 * @param {Object} options
 * @param {string} options.requestId - `$metadata.requestId` of the response
 * @param {Object} [options.usage] - `{ inputTokens, outputTokens }` from the body
 * @returns {{ inputTokens: number, outputTokens: number, totalTokens: number }}
 */
function extractTokens ({ requestId, usage = {} }) {
  const cached = requestIdsToTokens[requestId] || {}
  delete requestIdsToTokens[requestId]

  // Take the first finite candidate; using `||` here would discard a
  // legitimate 0 from the body, and a bare `??` would let NaN or '' (from a
  // missing/empty header) leak into the totals.
  const firstFiniteCount = (...candidates) => candidates.find(Number.isFinite) ?? 0

  const inputTokens = firstFiniteCount(usage.inputTokens, cached.inputTokensFromHeaders)
  const outputTokens = firstFiniteCount(usage.outputTokens, cached.outputTokensFromHeaders)

  return {
    inputTokens,
    outputTokens,
    totalTokens: inputTokens + outputTokens
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const agent = require('../../../plugins/agent')

const nock = require('nock')
const { expectedLLMObsLLMSpanEvent, deepEqualWithMockValues, MOCK_ANY } = require('../../util')
const { expectedLLMObsLLMSpanEvent, deepEqualWithMockValues } = require('../../util')
const { models, modelConfig } = require('../../../../../datadog-plugin-aws-sdk/test/fixtures/bedrockruntime')
const chai = require('chai')
const LLMObsAgentProxySpanWriter = require('../../../../src/llmobs/writers/spans/agentProxy')
Expand Down Expand Up @@ -78,22 +78,31 @@ describe('Plugin', () => {

nock('http://127.0.0.1:4566')
.post(`/model/${model.modelId}/invoke`)
.reply(200, response)
.reply(200, response, {
'x-amzn-bedrock-input-token-count': 50,
'x-amzn-bedrock-output-token-count': 70,
'x-amzn-requestid': Date.now().toString()
})

const command = new AWS.InvokeModelCommand(request)

const expectedOutput = { content: model.output }
if (model.outputRole) expectedOutput.role = model.outputRole

agent.use(traces => {
const span = traces[0][0]
const spanEvent = LLMObsAgentProxySpanWriter.prototype.append.getCall(0).args[0]
const expected = expectedLLMObsLLMSpanEvent({
span,
spanKind: 'llm',
name: 'bedrock-runtime.command',
inputMessages: [
{ content: model.userPrompt }
],
outputMessages: MOCK_ANY,
tokenMetrics: { input_tokens: 0, output_tokens: 0, total_tokens: 0 },
inputMessages: [{ content: model.userPrompt }],
outputMessages: [expectedOutput],
tokenMetrics: {
input_tokens: model.usage?.inputTokens ?? 50,
output_tokens: model.usage?.outputTokens ?? 70,
total_tokens: model.usage?.totalTokens ?? 120
},
modelName: model.modelId.split('.')[1].toLowerCase(),
modelProvider: model.provider.toLowerCase(),
metadata: {
Expand Down

0 comments on commit f01d385

Please sign in to comment.