diff --git a/base/audit_types.go b/base/audit_types.go index 186274f0c6..548ccbaf6e 100644 --- a/base/audit_types.go +++ b/base/audit_types.go @@ -10,6 +10,7 @@ package base import ( "fmt" + "reflect" "strconv" ) @@ -73,7 +74,7 @@ const ( var fieldsByGroup = map[fieldGroup]map[string]any{ fieldGroupGlobal: { AuditFieldTimestamp: "timestamp", - AuditFieldID: 123, + AuditFieldID: uint32(123), AuditFieldName: "event name", AuditFieldDescription: "event description", }, @@ -150,7 +151,7 @@ func (ed *EventDescriptor) expandOptionalFieldGroups(groups []fieldGroup) { func (i AuditID) MustValidateFields(f AuditFields) { if err := i.ValidateFields(f); err != nil { - panic(fmt.Errorf("audit event %s(%s) invalid:\n%v", AuditEvents[i].Name, i, err)) + panic(fmt.Errorf("audit event %q (%s) invalid:\n%v", i, AuditEvents[i].Name, err)) } } @@ -168,15 +169,38 @@ func (i AuditID) ValidateFields(f AuditFields) error { func mandatoryFieldsPresent(fields, mandatoryFields AuditFields, baseName string) error { me := &MultiError{} for k, v := range mandatoryFields { + if _, ok := fields[k]; !ok { + me = me.Append(fmt.Errorf("missing mandatory field %s", baseName+k)) + continue + } + if !matchingTypes(v, fields[k]) { + me = me.Append(fmt.Errorf("field value for %s%s must be of type %T but had %T", baseName, k, v, fields[k])) + continue + } // recurse if map if vv, ok := v.(map[string]any); ok { if pv, ok := fields[k].(map[string]any); ok { me = me.Append(mandatoryFieldsPresent(pv, vv, baseName+k+".")) } } - if _, ok := fields[k]; !ok { - me = me.Append(fmt.Errorf("missing mandatory field %s", baseName+k)) - } } return me.ErrorOrNil() } + +// matchingTypes returns true if the types of a and b are the same. +func matchingTypes(a, b any) bool { + typeOfA, typeOfB := reflect.TypeOf(a), reflect.TypeOf(b) + if typeOfA == nil || typeOfB == nil { + return typeOfA == typeOfB + } + // deref + if typeOfA.Kind() == reflect.Pointer && typeOfB.Kind() != reflect.Pointer { + typeOfA = typeOfA.Elem() + } else if typeOfB.Kind() == reflect.Pointer && typeOfA.Kind() != reflect.Pointer { + typeOfB = typeOfB.Elem() + } + if typeOfA.ConvertibleTo(typeOfB) { + return true + } + return typeOfA.Kind() == typeOfB.Kind() +} diff --git a/base/logger_audit.go b/base/logger_audit.go index 7b1d0c4a07..bba1b6eeee 100644 --- a/base/logger_audit.go +++ b/base/logger_audit.go @@ -31,7 +31,7 @@ func expandFields(id AuditID, ctx context.Context, globalFields AuditFields, add } // static event data - fields[AuditFieldID] = uint64(id) + fields[AuditFieldID] = uint32(id) fields[AuditFieldName] = AuditEvents[id].Name fields[AuditFieldDescription] = AuditEvents[id].Description @@ -86,7 +86,7 @@ func expandFields(id AuditID, ctx context.Context, globalFields AuditFields, add } } - fields[AuditFieldTimestamp] = time.Now() + fields[AuditFieldTimestamp] = time.Now().Format(time.RFC3339) fields.merge(ctx, globalFields) fields.merge(ctx, logCtx.RequestAdditionalAuditFields) diff --git a/rest/admin_api.go b/rest/admin_api.go index 9e82342bb8..08dea34a3b 100644 --- a/rest/admin_api.go +++ b/rest/admin_api.go @@ -793,9 +793,14 @@ func (h *handler) handleGetDbAuditConfig() error { // PUT/POST audit config for database func (h *handler) handlePutDbAuditConfig() error { - var body HandleDbAuditConfigBody + var bodyRaw []byte err := h.mutateDbConfig(func(config *DbConfig) error { - if err := h.readJSONInto(&body); err != nil { + bodyRaw, err := h.readBody() + if err != nil { + return err + } + var body HandleDbAuditConfigBody + if err := base.JSONUnmarshal(bodyRaw, &body); err != nil { return err } @@ -860,7 +865,7 @@ func (h *handler) handlePutDbAuditConfig() error { } base.Audit(h.ctx(), base.AuditIDAuditConfigChanged, base.AuditFields{ base.AuditFieldAuditScope: "db", - base.AuditFieldPayload: body, + base.AuditFieldPayload: string(bodyRaw), }) return nil } @@ -2062,7 +2067,7 @@ func (h *handler) putReplication() error { if err != nil { return err } - auditFields := base.AuditFields{base.AuditFieldReplicationID: replicationConfig.ID, base.AuditFieldPayload: body} + auditFields := base.AuditFields{base.AuditFieldReplicationID: replicationConfig.ID, base.AuditFieldPayload: string(body)} if created { h.writeStatus(http.StatusCreated, "Created") base.Audit(h.ctx(), base.AuditIDISGRCreate, auditFields) diff --git a/rest/api.go b/rest/api.go index 338030f03c..3bf3b698d0 100644 --- a/rest/api.go +++ b/rest/api.go @@ -324,6 +324,8 @@ func (h *handler) handlePostResync() error { action := h.getQuery("action") regenerateSequences, _ := h.getOptBoolQuery("regenerate_sequences", false) + reset := h.getBoolQuery("reset") + body, err := h.readBody() if err != nil { return err @@ -353,7 +355,7 @@ func (h *handler) handlePostResync() error { "database": h.db, "regenerateSequences": regenerateSequences, "collections": resyncPostReqBody.Scope, - "reset": h.getBoolQuery("reset"), + "reset": reset, }) if err != nil { return err @@ -367,7 +369,7 @@ func (h *handler) handlePostResync() error { base.Audit(h.ctx(), base.AuditIDDatabaseResyncStart, base.AuditFields{ "collections": resyncPostReqBody.Scope, "regenerate_sequences": regenerateSequences, - "reset": h.getQuery("reset"), + "reset": reset, }) } else { dbState := atomic.LoadUint32(&h.db.State)