Skip to content

Commit

Permalink
prevent all transaction methods from nil transaction errors (#1001)
Browse files Browse the repository at this point in the history
* prevent all transaction methods from executing against nil transaction object

* fix: incorrect nil logic for segments
  • Loading branch information
iamemilio authored Feb 24, 2025
1 parent 18613dc commit d55d218
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 26 deletions.
1 change: 0 additions & 1 deletion v3/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ require (
google.golang.org/protobuf v1.34.2
)


retract v3.22.0 // release process error corrected in v3.22.1

retract v3.25.0 // release process error corrected in v3.25.1
Expand Down
67 changes: 42 additions & 25 deletions v3/newrelic/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ type Transaction struct {
thread *thread
}

// nilTransaction guards against nil errors when handling a transaction.
func nilTransaction(txn *Transaction) bool {
return txn == nil || txn.thread == nil || txn.thread.txn == nil
}

// End finishes the Transaction. After that, subsequent calls to End or
// other Transaction methods have no effect. All segments and
// instrumentation must be completed before End is called.
func (txn *Transaction) End() {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}

Expand All @@ -55,15 +60,15 @@ func (txn *Transaction) End() {
// The set of options should be the complete set you wish to have in effect,
// just as if you were calling StartTransaction now with the same set of options.
func (txn *Transaction) SetOption(options ...TraceOption) {
if txn == nil || txn.thread == nil || txn.thread.txn == nil {
if nilTransaction(txn) {
return
}
txn.thread.txn.setOption(options...)
}

// Ignore prevents this transaction's data from being recorded.
func (txn *Transaction) Ignore() {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.Ignore(), "ignore transaction", nil)
Expand All @@ -72,7 +77,7 @@ func (txn *Transaction) Ignore() {
// SetName names the transaction. Use a limited set of unique names to
// ensure that Transactions are grouped usefully.
func (txn *Transaction) SetName(name string) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.SetName(name), "set transaction name", nil)
Expand All @@ -84,8 +89,7 @@ func (txn *Transaction) Name() string {
// This is called Name rather than GetName to be consistent with the prevailing naming
// conventions for the Go language, even though the underlying internal call must be called
// something else (like GetName) because there's already a Name struct member.

if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return ""
}
return txn.thread.GetName()
Expand Down Expand Up @@ -117,7 +121,7 @@ func (txn *Transaction) Name() string {
// way to directly control the recorded error's message, class, stacktrace,
// and attributes.
func (txn *Transaction) NoticeError(err error) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.NoticeError(err, false), "notice error", nil)
Expand Down Expand Up @@ -151,7 +155,7 @@ func (txn *Transaction) NoticeError(err error) {
// way to directly control the recorded error's message, class, stacktrace,
// and attributes.
func (txn *Transaction) NoticeExpectedError(err error) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.NoticeError(err, true), "notice error", nil)
Expand All @@ -166,7 +170,7 @@ func (txn *Transaction) NoticeExpectedError(err error) {
// For more information, see:
// https://docs.newrelic.com/docs/agents/manage-apm-agents/agent-metrics/collect-custom-attributes
func (txn *Transaction) AddAttribute(key string, value any) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.AddAttribute(key, value), "add attribute", nil)
Expand All @@ -176,10 +180,9 @@ func (txn *Transaction) AddAttribute(key string, value any) {
// belong to or interact with. This will propogate an attribute containing this information to all events that are
// a child of this transaction, like errors and spans.
func (txn *Transaction) SetUserID(userID string) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}

txn.thread.logAPIError(txn.thread.AddUserID(userID), "set user ID", nil)
}

Expand All @@ -192,6 +195,9 @@ func (txn *Transaction) SetUserID(userID string) {
// as well as log metrics depending on how your application is
// configured.
func (txn *Transaction) RecordLog(log LogData) {
if nilTransaction(txn) {
return
}
event, err := log.toLogEvent()
if err != nil {
txn.Application().app.Error("unable to record log", map[string]any{
Expand All @@ -212,6 +218,9 @@ func (txn *Transaction) RecordLog(log LogData) {
// present, the agent will look for distributed tracing headers using
// Transaction.AcceptDistributedTraceHeaders.
func (txn *Transaction) SetWebRequestHTTP(r *http.Request) {
if nilTransaction(txn) {
return
}
if r == nil {
txn.SetWebRequest(WebRequest{})
return
Expand Down Expand Up @@ -265,7 +274,7 @@ func reqBody(req *http.Request) *BodyBuffer {
// distributed tracing headers using Transaction.AcceptDistributedTraceHeaders.
// Use Transaction.SetWebRequestHTTP if you have a *http.Request.
func (txn *Transaction) SetWebRequest(r WebRequest) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
if IsSecurityAgentPresent() {
Expand All @@ -289,7 +298,7 @@ func (txn *Transaction) SetWebRequest(r WebRequest) {
// package middlewares. Therefore, you probably want to use this only if you
// are writing your own instrumentation middleware.
func (txn *Transaction) SetWebResponse(w http.ResponseWriter) http.ResponseWriter {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return w
}
return txn.thread.SetWebResponse(w)
Expand All @@ -304,7 +313,7 @@ func (txn *Transaction) StartSegmentNow() SegmentStartTime {
}

func (txn *Transaction) startSegmentAt(at time.Time) SegmentStartTime {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return SegmentStartTime{}
}
return txn.thread.startSegmentAt(at)
Expand All @@ -324,7 +333,11 @@ func (txn *Transaction) startSegmentAt(at time.Time) SegmentStartTime {
// // ... code you want to time here ...
// segment.End()
func (txn *Transaction) StartSegment(name string) *Segment {
if IsSecurityAgentPresent() && txn != nil && txn.thread != nil && txn.thread.thread != nil && txn.thread.thread.threadID > 0 {
if nilTransaction(txn) {
return &Segment{} // return a non-nil Segment to avoid nil dereference
}

if IsSecurityAgentPresent() && txn.thread.thread != nil && txn.thread.thread.threadID > 0 {
// async segment start
secureAgent.SendEvent("NEW_GOROUTINE_LINKER", txn.thread.getCsecData())
}
Expand All @@ -346,7 +359,7 @@ func (txn *Transaction) StartSegment(name string) *Segment {
// StartExternalSegment calls InsertDistributedTraceHeaders, so you don't need
// to use it for outbound HTTP calls: Just use StartExternalSegment!
func (txn *Transaction) InsertDistributedTraceHeaders(hdrs http.Header) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.CreateDistributedTracePayload(hdrs)
Expand All @@ -367,7 +380,7 @@ func (txn *Transaction) InsertDistributedTraceHeaders(hdrs http.Header) {
// context headers. Only when those are not found will it look for the New
// Relic distributed tracing header.
func (txn *Transaction) AcceptDistributedTraceHeaders(t TransportType, hdrs http.Header) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.AcceptDistributedTraceHeaders(t, hdrs), "accept trace payload", nil)
Expand All @@ -379,6 +392,10 @@ func (txn *Transaction) AcceptDistributedTraceHeaders(t TransportType, hdrs http
// convert the JSON string to http headers. There is no guarantee that the header data found in JSON
// is correct beyond conforming to the expected types and syntax.
func (txn *Transaction) AcceptDistributedTraceHeadersFromJSON(t TransportType, jsondata string) error {
if nilTransaction(txn) { // do no work if txn is nil
return nil
}

hdrs, err := DistributedTraceHeadersFromJSON(jsondata)
if err != nil {
return err
Expand Down Expand Up @@ -465,7 +482,7 @@ func DistributedTraceHeadersFromJSON(jsondata string) (hdrs http.Header, err err

// Application returns the Application which started the transaction.
func (txn *Transaction) Application() *Application {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
return txn.thread.Application()
Expand All @@ -484,7 +501,7 @@ func (txn *Transaction) Application() *Application {
// monitoring is disabled, the application is not connected, or an error
// occurred. It is safe to call the pointer's methods if it is nil.
func (txn *Transaction) BrowserTimingHeader() *BrowserTimingHeader {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
b, err := txn.thread.BrowserTimingHeader()
Expand All @@ -506,7 +523,7 @@ func (txn *Transaction) BrowserTimingHeader() *BrowserTimingHeader {
// Note that any segments that end after the transaction ends will not
// be reported.
func (txn *Transaction) NewGoroutine() *Transaction {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
newTxn := txn.thread.NewGoroutine()
Expand All @@ -519,7 +536,7 @@ func (txn *Transaction) NewGoroutine() *Transaction {
// GetTraceMetadata returns distributed tracing identifiers. Empty
// string identifiers are returned if the transaction has finished.
func (txn *Transaction) GetTraceMetadata() TraceMetadata {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return TraceMetadata{}
}
return txn.thread.GetTraceMetadata()
Expand All @@ -528,7 +545,7 @@ func (txn *Transaction) GetTraceMetadata() TraceMetadata {
// GetLinkingMetadata returns the fields needed to link data to a trace or
// entity.
func (txn *Transaction) GetLinkingMetadata() LinkingMetadata {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return LinkingMetadata{}
}
return txn.thread.GetLinkingMetadata()
Expand All @@ -539,21 +556,21 @@ func (txn *Transaction) GetLinkingMetadata() LinkingMetadata {
// must be enabled for transactions to be sampled. False is returned if
// the Transaction has finished.
func (txn *Transaction) IsSampled() bool {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return false
}
return txn.thread.IsSampled()
}

func (txn *Transaction) GetCsecAttributes() map[string]any {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
return txn.thread.getCsecAttributes()
}

func (txn *Transaction) SetCsecAttributes(key string, value any) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.setCsecAttributes(key, value)
Expand Down
72 changes: 72 additions & 0 deletions v3/newrelic/transaction_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package newrelic

import (
"fmt"
"net/http"
"testing"
)

func TestTransaction_MethodsWithNilTransaction(t *testing.T) {
var nilTxn *Transaction

defer func() {
if r := recover(); r != nil {
t.Errorf("panics should not occur on methods of Transaction: %v", r)
}
}()

// Ensure no panic occurs when calling methods on a nil transaction
nilTxn.End()
nilTxn.SetOption()
nilTxn.Ignore()
nilTxn.SetName("test")
name := nilTxn.Name()
if name != "" {
t.Errorf("expected empty string, got %s", name)
}
nilTxn.NoticeError(fmt.Errorf("test error"))
nilTxn.NoticeExpectedError(fmt.Errorf("test expected error"))
nilTxn.AddAttribute("key", "value")
nilTxn.SetUserID("user123")
nilTxn.RecordLog(LogData{})
nilTxn.SetWebRequestHTTP(nil)
nilTxn.SetWebRequest(WebRequest{})
nilTxn.SetWebResponse(nil)
nilTxn.StartSegmentNow()
nilTxn.StartSegment("test segment")
nilTxn.InsertDistributedTraceHeaders(http.Header{})
nilTxn.AcceptDistributedTraceHeaders(TransportHTTP, http.Header{})
err := nilTxn.AcceptDistributedTraceHeadersFromJSON(TransportHTTP, "{}")
if err != nil {
t.Errorf("expected no error, got %v", err)
}
app := nilTxn.Application()
if app != nil {
t.Errorf("expected nil, got %v", app)
}
bth := nilTxn.BrowserTimingHeader()
if bth != nil {
t.Errorf("expected nil, got %v", bth)
}
newTxn := nilTxn.NewGoroutine()
if newTxn != nil {
t.Errorf("expected nil, got %v", newTxn)
}
traceMetadata := nilTxn.GetTraceMetadata()
if traceMetadata != (TraceMetadata{}) {
t.Errorf("expected empty TraceMetadata, got %v", traceMetadata)
}
linkingMetadata := nilTxn.GetLinkingMetadata()
if linkingMetadata != (LinkingMetadata{}) {
t.Errorf("expected empty LinkingMetadata, got %v", linkingMetadata)
}
isSampled := nilTxn.IsSampled()
if isSampled {
t.Errorf("expected false, got %v", isSampled)
}
csecAttributes := nilTxn.GetCsecAttributes()
if csecAttributes != nil {
t.Errorf("expected nil, got %v", csecAttributes)
}
nilTxn.SetCsecAttributes("key", "value")
}

0 comments on commit d55d218

Please sign in to comment.