Skip to content

Commit

Permalink
Merge pull request #133 from itaischwartz/feature/contextfunc
Browse files Browse the repository at this point in the history
feat: add context function
  • Loading branch information
moul authored Dec 29, 2022
2 parents 8fd1680 + 6e85b09 commit ce4cb09
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
1 change: 1 addition & 0 deletions go.mod

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 22 additions & 9 deletions zapgorm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@ import (
"time"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)

type ContextFn func(ctx context.Context) []zapcore.Field

type Logger struct {
ZapLogger *zap.Logger
LogLevel gormlogger.LogLevel
SlowThreshold time.Duration
SkipCallerLookup bool
IgnoreRecordNotFoundError bool
Context ContextFn
}

func New(zapLogger *zap.Logger) Logger {
Expand All @@ -28,6 +32,7 @@ func New(zapLogger *zap.Logger) Logger {
SlowThreshold: 100 * time.Millisecond,
SkipCallerLookup: false,
IgnoreRecordNotFoundError: false,
Context: nil,
}
}

Expand All @@ -42,45 +47,47 @@ func (l Logger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
LogLevel: level,
SkipCallerLookup: l.SkipCallerLookup,
IgnoreRecordNotFoundError: l.IgnoreRecordNotFoundError,
Context: l.Context,
}
}

func (l Logger) Info(ctx context.Context, str string, args ...interface{}) {
if l.LogLevel < gormlogger.Info {
return
}
l.logger().Sugar().Debugf(str, args...)
l.logger(ctx).Sugar().Debugf(str, args...)
}

func (l Logger) Warn(ctx context.Context, str string, args ...interface{}) {
if l.LogLevel < gormlogger.Warn {
return
}
l.logger().Sugar().Warnf(str, args...)
l.logger(ctx).Sugar().Warnf(str, args...)
}

func (l Logger) Error(ctx context.Context, str string, args ...interface{}) {
if l.LogLevel < gormlogger.Error {
return
}
l.logger().Sugar().Errorf(str, args...)
l.logger(ctx).Sugar().Errorf(str, args...)
}

func (l Logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= 0 {
return
}
elapsed := time.Since(begin)
logger := l.logger(ctx)
switch {
case err != nil && l.LogLevel >= gormlogger.Error && (!l.IgnoreRecordNotFoundError || !errors.Is(err, gorm.ErrRecordNotFound)):
sql, rows := fc()
l.logger().Error("trace", zap.Error(err), zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
logger.Error("trace", zap.Error(err), zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
case l.SlowThreshold != 0 && elapsed > l.SlowThreshold && l.LogLevel >= gormlogger.Warn:
sql, rows := fc()
l.logger().Warn("trace", zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
logger.Warn("trace", zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
case l.LogLevel >= gormlogger.Info:
sql, rows := fc()
l.logger().Debug("trace", zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
logger.Debug("trace", zap.Duration("elapsed", elapsed), zap.Int64("rows", rows), zap.String("sql", sql))
}
}

Expand All @@ -89,7 +96,13 @@ var (
zapgormPackage = filepath.Join("moul.io", "zapgorm2")
)

func (l Logger) logger() *zap.Logger {
func (l Logger) logger(ctx context.Context) *zap.Logger {
logger := l.ZapLogger
if l.Context != nil {
fields := l.Context(ctx)
logger = logger.With(fields...)
}

for i := 2; i < 15; i++ {
_, file, _, ok := runtime.Caller(i)
switch {
Expand All @@ -98,8 +111,8 @@ func (l Logger) logger() *zap.Logger {
case strings.Contains(file, gormPackage):
case strings.Contains(file, zapgormPackage):
default:
return l.ZapLogger.WithOptions(zap.AddCallerSkip(i))
return logger.WithOptions(zap.AddCallerSkip(i))
}
}
return l.ZapLogger
return logger
}
45 changes: 45 additions & 0 deletions zapgorm2_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package zapgorm2_test

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
"gorm.io/gorm"
"moul.io/zapgorm2"
)
Expand All @@ -14,3 +20,42 @@ func Example() {
// do stuff normally
var _ = db // avoid "unused variable" warn
}

func setupLogsCapture() (*zap.Logger, *observer.ObservedLogs) {
core, logs := observer.New(zap.WarnLevel)
return zap.New(core), logs
}

func TestContextFunc(t *testing.T) {
zaplogger, logs := setupLogsCapture()
logger := zapgorm2.New(zaplogger)

type ctxKey string
key1 := ctxKey("Key")
key2 := ctxKey("Key2")

value1 := "Value"
value2 := "Value2"

ctx := context.WithValue(context.Background(), key1, value1)
ctx = context.WithValue(ctx, key2, value2)
logger.Context = func(ctx context.Context) []zapcore.Field {
ctxValue, ok := (ctx.Value(key1)).(string)
require.True(t, ok)
ctxValue2, ok := (ctx.Value(key2)).(string)
require.True(t, ok)
return []zapcore.Field{zap.String(string(key1), ctxValue), zap.String(string(key2), ctxValue2)}
}

db, err := gorm.Open(nil, &gorm.Config{Logger: logger})
require.NoError(t, err)

db.Logger.Error(ctx, "test")
require.Equal(t, 1, logs.Len())
entry := logs.All()[0]
require.Equal(t, zap.ErrorLevel, entry.Level)
require.Equal(t, "test", entry.Message)
require.Equal(t, value1, entry.ContextMap()[string(key1)])
require.Equal(t, value2, entry.ContextMap()[string(key2)])

}

0 comments on commit ce4cb09

Please sign in to comment.