Skip to content

Commit

Permalink
feat: Added new API method withLlmCustomAttributes to run a functio…
Browse files Browse the repository at this point in the history
…n in a LLM context (#2437)

The context will be used to assign custom attributes to every LLM event produced within the function
  • Loading branch information
MikeVaz authored Aug 22, 2024
1 parent 0448927 commit 57e6be9
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 6 deletions.
49 changes: 49 additions & 0 deletions api.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const obfuscate = require('./lib/util/sql/obfuscate')
const { DESTINATIONS } = require('./lib/config/attribute-filter')
const parse = require('module-details-from-path')
const { isSimpleObject } = require('./lib/util/objects')
const { AsyncLocalStorage } = require('async_hooks')

/*
*
Expand Down Expand Up @@ -1902,4 +1903,52 @@ API.prototype.ignoreApdex = function ignoreApdex() {
transaction.ignoreApdex = true
}

/**

Check warning on line 1906 in api.js

View workflow job for this annotation

GitHub Actions / lint (lts/*)

Missing JSDoc @returns declaration
* Run a function with the passed in LLM context as the active context and return its return value.
*
* An example of setting a custom attribute:
*
* newrelic.withLlmCustomAttributes({'llm.someAttribute': 'someValue'}, () => {
* return;
* })
* @param {Object} context LLM custom attributes context
* @param {Function} callback The function to execute in context.
*/
API.prototype.withLlmCustomAttributes = function withLlmCustomAttributes(context, callback) {
context = context || {}
const metric = this.agent.metrics.getOrCreateMetric(
NAMES.SUPPORTABILITY.API + '/withLlmCustomAttributes'
)
metric.incrementCallCount()

const transaction = this.agent.tracer.getTransaction()

if (!callback || typeof callback !== 'function') {
logger.warn('withLlmCustomAttributes must be used with a valid callback')
return
}

if (!transaction) {
logger.warn('withLlmCustomAttributes must be called within the scope of a transaction.')
return callback()
}

for (const [key, value] of Object.entries(context)) {
if (typeof value === 'object' || typeof value === 'function') {
logger.warn(`Invalid attribute type for ${key}. Skipped.`)
delete context[key]
} else if (key.indexOf('llm.') !== 0) {
logger.warn(`Invalid attribute name ${key}. Renamed to "llm.${key}".`)
delete context[key]
context[`llm.${key}`] = value
}
}

transaction._llmContextManager = transaction._llmContextManager || new AsyncLocalStorage()
const parentContext = transaction._llmContextManager.getStore() || {}

const fullContext = Object.assign({}, parentContext, context)
return transaction._llmContextManager.run(fullContext, callback)
}

module.exports = API
8 changes: 7 additions & 1 deletion lib/instrumentation/aws-sdk/v3/bedrock.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const { DESTINATIONS } = require('../../../config/attribute-filter')
const { AI } = require('../../../metrics/names')
const { RecorderSpec } = require('../../../shim/specs')
const InstrumentationDescriptor = require('../../../instrumentation-descriptor')
const { extractLlmContext } = require('../../../util/llm-utils')

let TRACKING_METRIC

Expand Down Expand Up @@ -55,7 +56,12 @@ function isStreamingEnabled({ commandName, config }) {
*/
function recordEvent({ agent, type, msg }) {
msg.serialize()
agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg])
const llmContext = extractLlmContext(agent)

agent.customEventAggregator.add([
{ type, timestamp: Date.now() },
Object.assign({}, msg, llmContext)
])
}

/**
Expand Down
8 changes: 7 additions & 1 deletion lib/instrumentation/langchain/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
const {
AI: { LANGCHAIN }
} = require('../../metrics/names')
const { extractLlmContext } = require('../../util/llm-utils')

const common = module.exports

Expand Down Expand Up @@ -49,7 +50,12 @@ common.mergeMetadata = function mergeMetadata(localMeta = {}, paramsMeta = {}) {
*/
common.recordEvent = function recordEvent({ agent, type, msg, pkgVersion }) {
agent.metrics.getOrCreateMetric(`${LANGCHAIN.TRACKING_PREFIX}/${pkgVersion}`).incrementCallCount()
agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg])
const llmContext = extractLlmContext(agent)

agent.customEventAggregator.add([
{ type, timestamp: Date.now() },
Object.assign({}, msg, llmContext)
])
}

/**
Expand Down
8 changes: 7 additions & 1 deletion lib/instrumentation/openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const {
LlmErrorMessage
} = require('../../lib/llm-events/openai')
const { RecorderSpec } = require('../../lib/shim/specs')
const { extractLlmContext } = require('../util/llm-utils')

const MIN_VERSION = '4.0.0'
const MIN_STREAM_VERSION = '4.12.2'
Expand Down Expand Up @@ -75,7 +76,12 @@ function decorateSegment({ shim, result, apiKey }) {
* @param {object} params.msg LLM event
*/
function recordEvent({ agent, type, msg }) {
agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg])
const llmContext = extractLlmContext(agent)

agent.customEventAggregator.add([
{ type, timestamp: Date.now() },
Object.assign({}, msg, llmContext)
])
}

/**
Expand Down
34 changes: 34 additions & 0 deletions lib/util/llm-utils.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2020 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

exports = module.exports = { extractLlmContext, extractLlmAttributes }

/**
* Extract LLM attributes from the LLM context
*
* @param {Object} context LLM context object
* @returns {Object} LLM custom attributes
*/
function extractLlmAttributes(context) {
return Object.keys(context).reduce((result, key) => {
if (key.indexOf('llm.') === 0) {
result[key] = context[key]
}
return result
}, {})
}

/**
* Extract LLM context from the active transaction
*
* @param {Agent} agent NR agent instance
* @returns {Object} LLM context object
*/
function extractLlmContext(agent) {
const context = agent.tracer.getTransaction()?._llmContextManager?.getStore() || {}
return extractLlmAttributes(context)
}
110 changes: 110 additions & 0 deletions test/unit/api/api-llm.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,116 @@ tap.test('Agent API LLM methods', (t) => {
})
})

t.test('withLlmCustomAttributes should handle no active transaction', (t) => {
const { api } = t.context
t.equal(
api.withLlmCustomAttributes({ test: 1 }, () => {
t.equal(loggerMock.warn.callCount, 1)
return 1
}),
1
)
t.end()
})

t.test('withLlmCustomAttributes should handle an empty store', (t) => {
const { api } = t.context
const agent = api.agent

helper.runInTransaction(api.agent, (tx) => {
agent.tracer.getTransaction = () => {
return tx
}
t.equal(
api.withLlmCustomAttributes(null, () => {
return 1
}),
1
)
t.end()
})
})

t.test('withLlmCustomAttributes should handle no callback', (t) => {
const { api } = t.context
const agent = api.agent
helper.runInTransaction(api.agent, (tx) => {
agent.tracer.getTransaction = () => {
return tx
}
api.withLlmCustomAttributes({ test: 1 }, null)
t.equal(loggerMock.warn.callCount, 1)
t.end()
})
})

t.test('withLlmCustomAttributes should normalize attributes', (t) => {
const { api } = t.context
const agent = api.agent
helper.runInTransaction(api.agent, (tx) => {
agent.tracer.getTransaction = () => {
return tx
}
api.withLlmCustomAttributes(
{
'toRename': 'value1',
'llm.number': 1,
'llm.boolean': true,
'toDelete': () => {},
'toDelete2': {},
'toDelete3': []
},
() => {
const contextManager = tx._llmContextManager
const parentContext = contextManager.getStore()
t.equal(parentContext['llm.toRename'], 'value1')
t.notOk(parentContext.toDelete)
t.notOk(parentContext.toDelete2)
t.notOk(parentContext.toDelete3)
t.equal(parentContext['llm.number'], 1)
t.equal(parentContext['llm.boolean'], true)
t.end()
}
)
})
})

t.test('withLlmCustomAttributes should support branching', (t) => {
const { api } = t.context
const agent = api.agent
t.autoend()
helper.runInTransaction(api.agent, (tx) => {
agent.tracer.getTransaction = () => {
return tx
}
api.withLlmCustomAttributes(
{ 'llm.step': '1', 'llm.path': 'root', 'llm.name': 'root' },
() => {
const contextManager = tx._llmContextManager
const context = contextManager.getStore()
t.equal(context[`llm.step`], '1')
t.equal(context['llm.path'], 'root')
t.equal(context['llm.name'], 'root')
api.withLlmCustomAttributes({ 'llm.step': '1.1', 'llm.path': 'root/1' }, () => {
const contextManager = tx._llmContextManager
const context = contextManager.getStore()
t.equal(context[`llm.step`], '1.1')
t.equal(context['llm.path'], 'root/1')
t.equal(context['llm.name'], 'root')
})
api.withLlmCustomAttributes({ 'llm.step': '1.2', 'llm.path': 'root/2' }, () => {
const contextManager = tx._llmContextManager
const context = contextManager.getStore()
t.equal(context[`llm.step`], '1.2')
t.equal(context['llm.path'], 'root/2')
t.equal(context['llm.name'], 'root')
t.end()
})
}
)
})
})

t.test('setLlmTokenCount should register callback to calculate token counts', async (t) => {
const { api, agent } = t.context
function callback(model, content) {
Expand Down
2 changes: 1 addition & 1 deletion test/unit/api/stub.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
const tap = require('tap')
const API = require('../../../stub_api')

const EXPECTED_API_COUNT = 36
const EXPECTED_API_COUNT = 37

tap.test('Agent API - Stubbed Agent API', (t) => {
t.autoend()
Expand Down
1 change: 1 addition & 0 deletions test/unit/instrumentation/openai.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,6 @@ test('openai unit tests', (t) => {
t.equal(isWrapped, false, 'should not wrap chat completions create')
t.end()
})

t.end()
})
74 changes: 74 additions & 0 deletions test/unit/util/llm-utils.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2023 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

const tap = require('tap')
const { extractLlmAttributes, extractLlmContext } = require('../../../lib/util/llm-utils')
const { AsyncLocalStorage } = require('async_hooks')

tap.test('extractLlmAttributes', (t) => {
const context = {
'skip': 1,
'llm.get': 2,
'fllm.skip': 3
}

const llmContext = extractLlmAttributes(context)
t.notOk(llmContext.skip)
t.notOk(llmContext['fllm.skip'])
t.equal(llmContext['llm.get'], 2)
t.end()
})

tap.test('extractLlmContext', (t) => {
t.beforeEach((t) => {
const tx = {
_llmContextManager: new AsyncLocalStorage()
}
t.context.agent = {
tracer: {
getTransaction: () => {
return tx
}
}
}
t.context.tx = tx
})

t.test('handle empty context', (t) => {
const { tx, agent } = t.context
tx._llmContextManager.run(null, () => {
const llmContext = extractLlmContext(agent)
t.equal(typeof llmContext, 'object')
t.equal(Object.entries(llmContext).length, 0)
t.end()
})
})

t.test('extract LLM context', (t) => {
const { tx, agent } = t.context
tx._llmContextManager.run({ 'llm.test': 1, 'skip': 2 }, () => {
const llmContext = extractLlmContext(agent)
t.equal(llmContext['llm.test'], 1)
t.notOk(llmContext.skip)
t.end()
})
})

t.test('no transaction', (t) => {
const { tx, agent } = t.context
agent.tracer.getTransaction = () => {
return null
}
tx._llmContextManager.run(null, () => {
const llmContext = extractLlmContext(agent)
t.equal(typeof llmContext, 'object')
t.equal(Object.entries(llmContext).length, 0)
t.end()
})
})
t.end()
})
Loading

0 comments on commit 57e6be9

Please sign in to comment.