Skip to content

Commit

Permalink
interceptors: Update logging interceptor Reporter to re-extract field…
Browse files Browse the repository at this point in the history
…s from context before logging (#702)

When using logging.WithFieldsFromContext, if the value being extracted
as a log field is modified after the logging interceptor initializes the
Reporter before the underlying handler is called, then the updated value
will not be reflected in the log message.

To fix this, re-extract fields from the context before logging them in
PostCall, PostMsgSend and PostMsgReceive, ensuring the updated values in
the context are logged.

Signed-off-by: Chance Zibolski <chance.zibolski@gmail.com>
  • Loading branch information
chancez authored Apr 8, 2024
1 parent 3834477 commit ea545dc
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 5 deletions.
12 changes: 12 additions & 0 deletions interceptors/logging/interceptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func (c *reporter) PostCall(err error, duration time.Duration) {
if err != nil {
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
}
if c.opts.fieldsFromCtxCallMetaFn != nil {
// fieldsFromCtxFn dups override the existing fields.
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
}
c.logger.Log(c.ctx, c.opts.levelFunc(code), "finished call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)
}

Expand All @@ -50,6 +54,10 @@ func (c *reporter) PostMsgSend(payload any, err error, duration time.Duration) {
if err != nil {
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
}
if c.opts.fieldsFromCtxCallMetaFn != nil {
// fieldsFromCtxFn dups override the existing fields.
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
}
if !c.startCallLogged && has(c.opts.loggableEvents, StartCall) {
c.startCallLogged = true
c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)
Expand Down Expand Up @@ -97,6 +105,10 @@ func (c *reporter) PostMsgReceive(payload any, err error, duration time.Duration
if err != nil {
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
}
if c.opts.fieldsFromCtxCallMetaFn != nil {
// fieldsFromCtxFn dups override the existing fields.
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
}
if !c.startCallLogged && has(c.opts.loggableEvents, StartCall) {
c.startCallLogged = true
c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)
Expand Down
28 changes: 24 additions & 4 deletions interceptors/logging/interceptors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -172,9 +173,14 @@ type loggingClientServerSuite struct {
*baseLoggingSuite
}

func customFields(_ context.Context) logging.Fields {
func customFields(ctx context.Context) logging.Fields {
var val string
n := testpb.ExtractCtxTestNumber(ctx)
if n != nil {
val = strconv.Itoa(*n)
}
// Add custom fields. The second one overrides the first one.
return logging.Fields{"custom-field", "foo", "custom-field", "yolo"}
return logging.Fields{"custom-field", "foo", "custom-field", "yolo", "custom-ctx-field", val}
}

func TestSuite(t *testing.T) {
Expand Down Expand Up @@ -232,13 +238,17 @@ func (s *loggingClientServerSuite) TestPing() {
assert.Equal(s.T(), logging.LevelDebug, serverStartCallLogLine.lvl)
assert.Equal(s.T(), "started call", serverStartCallLogLine.msg)
_ = assertStandardFields(s.T(), logging.KindServerFieldValue, serverStartCallLogLine.fields, "Ping", interceptors.Unary)
// This field is zero initially, but will be updated by the service, which we should see after the call is finished
serverStartCallLogLine.fields.AssertField(s.T(), "custom-ctx-field", "0")

serverFinishCallLogLine := lines[2]
assert.Equal(s.T(), logging.LevelDebug, serverFinishCallLogLine.lvl)
assert.Equal(s.T(), "finished call", serverFinishCallLogLine.msg)
serverFinishCallFields := assertStandardFields(s.T(), logging.KindServerFieldValue, serverFinishCallLogLine.fields, "Ping", interceptors.Unary)
serverFinishCallFields.AssertFieldNotEmpty(s.T(), "peer.address").
AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertFieldNotEmpty(s.T(), "grpc.start_time").
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
AssertField(s.T(), "grpc.code", "OK").
Expand All @@ -249,6 +259,8 @@ func (s *loggingClientServerSuite) TestPing() {
assert.Equal(s.T(), "finished call", clientFinishCallLogLine.msg)
clientFinishCallFields := assertStandardFields(s.T(), logging.KindClientFieldValue, clientFinishCallLogLine.fields, "Ping", interceptors.Unary)
clientFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertField(s.T(), "grpc.request.value", "something").
AssertFieldNotEmpty(s.T(), "grpc.start_time").
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
Expand Down Expand Up @@ -285,6 +297,8 @@ func (s *loggingClientServerSuite) TestPingList() {
assert.Equal(s.T(), "finished call", serverFinishCallLogLine.msg)
serverFinishCallFields := assertStandardFields(s.T(), logging.KindServerFieldValue, serverFinishCallLogLine.fields, "PingList", interceptors.ServerStream)
serverFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertFieldNotEmpty(s.T(), "peer.address").
AssertFieldNotEmpty(s.T(), "grpc.start_time").
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
Expand All @@ -297,6 +311,8 @@ func (s *loggingClientServerSuite) TestPingList() {
clientFinishCallFields := assertStandardFields(s.T(), logging.KindClientFieldValue, clientFinishCallLogLine.fields, "PingList", interceptors.ServerStream)
clientFinishCallFields.AssertFieldNotEmpty(s.T(), "grpc.start_time").
AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
AssertField(s.T(), "grpc.code", "OK").
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())
Expand Down Expand Up @@ -344,23 +360,27 @@ func (s *loggingClientServerSuite) TestPingError_WithCustomLevels() {
assert.Equal(t, "finished call", serverFinishCallLogLine.msg)
serverFinishCallFields := assertStandardFields(t, logging.KindServerFieldValue, serverFinishCallLogLine.fields, "PingError", interceptors.Unary)
serverFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertFieldNotEmpty(t, "peer.address").
AssertFieldNotEmpty(t, "grpc.start_time").
AssertFieldNotEmpty(t, "grpc.request.deadline").
AssertField(t, "grpc.code", tcase.code.String()).
AssertField(t, "grpc.error", fmt.Sprintf("rpc error: code = %s desc = Userspace error", tcase.code.String())).
AssertFieldNotEmpty(t, "grpc.time_ms").AssertNoMoreTags(t)
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())

clientFinishCallLogLine := lines[0]
assert.Equal(t, tcase.level, clientFinishCallLogLine.lvl)
assert.Equal(t, "finished call", clientFinishCallLogLine.msg)
clientFinishCallFields := assertStandardFields(t, logging.KindClientFieldValue, clientFinishCallLogLine.fields, "PingError", interceptors.Unary)
clientFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
// should be updated from 0 to 42
AssertField(s.T(), "custom-ctx-field", "42").
AssertFieldNotEmpty(t, "grpc.start_time").
AssertFieldNotEmpty(t, "grpc.request.deadline").
AssertField(t, "grpc.code", tcase.code.String()).
AssertField(t, "grpc.error", fmt.Sprintf("rpc error: code = %s desc = Userspace error", tcase.code.String())).
AssertFieldNotEmpty(t, "grpc.time_ms").AssertNoMoreTags(t)
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())
})
}
}
Expand Down
25 changes: 25 additions & 0 deletions testing/testpb/interceptor_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,33 @@ func (s *InterceptorTestSuite) ServerAddr() string {
return s.serverAddr
}

type ctxTestNumber struct{}

var (
ctxTestNumberKey = &ctxTestNumber{}
zero = 0
)

func ExtractCtxTestNumber(ctx context.Context) *int {
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok {
return v
}
return &zero
}

// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// newCtx := newContext(ctx, log, opts)
newCtx := ctx
resp, err := handler(newCtx, req)
return resp, err
}
}

func (s *InterceptorTestSuite) SimpleCtx() context.Context {
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
ctx = context.WithValue(ctx, ctxTestNumberKey, 1)
s.cancels = append(s.cancels, cancel)
return ctx
}
Expand Down
7 changes: 6 additions & 1 deletion testing/testpb/pingservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*Pi
return &PingEmptyResponse{}, nil
}

func (s *TestPingService) Ping(_ context.Context, ping *PingRequest) (*PingResponse, error) {
func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) {
// Modify the ctx value to verify the logger sees the value updated from the initial value
n := ExtractCtxTestNumber(ctx)
if n != nil {
*n = 42
}
// Send user trailers and headers.
return &PingResponse{Value: ping.Value, Counter: 0}, nil
}
Expand Down

0 comments on commit ea545dc

Please sign in to comment.