From eeec0182fd2021fac8da3836204fc5a01e754b0c Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Fri, 21 Jan 2022 14:17:18 -0700 Subject: [PATCH] Use API-level proto-to-id helper for audit fields A few call sites were missed during the the recent change to introduce back-compat on "ensure leading slash" behavior where the audit fields are being populated. This could cause an API that was accepted to omit the SPIFFE ID or parent ID fields from the audit log. This change fixes those call sites to use the proper helpers from the API. The IDFromProto method is exported used instead of the more stringent TrustDomain*FromProto methods since audit logging doesn't need that extra validation. Signed-off-by: Andrew Harding --- pkg/server/api/entry/v1/service.go | 11 +++++------ pkg/server/api/id.go | 9 +++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/server/api/entry/v1/service.go b/pkg/server/api/entry/v1/service.go index ed32f8259c..3b965db371 100644 --- a/pkg/server/api/entry/v1/service.go +++ b/pkg/server/api/entry/v1/service.go @@ -9,7 +9,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" - "github.com/spiffe/spire/pkg/common/idutil" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/server/api" "github.com/spiffe/spire/pkg/server/api/rpccontext" @@ -186,7 +185,7 @@ func (s *Service) BatchCreateEntry(ctx context.Context, req *entryv1.BatchCreate r := s.createEntry(ctx, eachEntry, req.OutputMask) results = append(results, r) rpccontext.AuditRPCWithTypesStatus(ctx, r.Status, func() logrus.Fields { - return fieldsFromEntryProto(eachEntry, nil) + return fieldsFromEntryProto(ctx, eachEntry, nil) }) } @@ -241,7 +240,7 @@ func (s *Service) BatchUpdateEntry(ctx context.Context, req *entryv1.BatchUpdate e := s.updateEntry(ctx, eachEntry, req.InputMask, req.OutputMask) results = append(results, e) rpccontext.AuditRPCWithTypesStatus(ctx, e.Status, func() logrus.Fields { - return fieldsFromEntryProto(eachEntry, req.InputMask) + return fieldsFromEntryProto(ctx, eachEntry, req.InputMask) }) } @@ -432,7 +431,7 @@ func (s *Service) updateEntry(ctx context.Context, e *types.Entry, inputMask *ty } } -func fieldsFromEntryProto(proto *types.Entry, inputMask *types.EntryMask) logrus.Fields { +func fieldsFromEntryProto(ctx context.Context, proto *types.Entry, inputMask *types.EntryMask) logrus.Fields { fields := logrus.Fields{} if proto == nil { @@ -444,14 +443,14 @@ func fieldsFromEntryProto(proto *types.Entry, inputMask *types.EntryMask) logrus } if (inputMask == nil || inputMask.SpiffeId) && proto.SpiffeId != nil { - id, err := idutil.IDFromProto(proto.SpiffeId) + id, err := api.IDFromProto(ctx, proto.SpiffeId) if err == nil { fields[telemetry.SPIFFEID] = id.String() } } if (inputMask == nil || inputMask.ParentId) && proto.ParentId != nil { - id, err := idutil.IDFromProto(proto.ParentId) + id, err := api.IDFromProto(ctx, proto.ParentId) if err == nil { fields[telemetry.ParentID] = id.String() } diff --git a/pkg/server/api/id.go b/pkg/server/api/id.go index a7ebb5cd8b..79eea63589 100644 --- a/pkg/server/api/id.go +++ b/pkg/server/api/id.go @@ -19,7 +19,7 @@ var ( ) func TrustDomainMemberIDFromProto(ctx context.Context, td spiffeid.TrustDomain, protoID *types.SPIFFEID) (spiffeid.ID, error) { - id, err := idFromProto(ctx, protoID) + id, err := IDFromProto(ctx, protoID) if err != nil { return spiffeid.ID{}, err } @@ -40,7 +40,7 @@ func VerifyTrustDomainMemberID(td spiffeid.TrustDomain, id spiffeid.ID) error { } func TrustDomainAgentIDFromProto(ctx context.Context, td spiffeid.TrustDomain, protoID *types.SPIFFEID) (spiffeid.ID, error) { - id, err := idFromProto(ctx, protoID) + id, err := IDFromProto(ctx, protoID) if err != nil { return spiffeid.ID{}, err } @@ -64,7 +64,7 @@ func VerifyTrustDomainAgentID(td spiffeid.TrustDomain, id spiffeid.ID) error { } func TrustDomainWorkloadIDFromProto(ctx context.Context, td spiffeid.TrustDomain, protoID *types.SPIFFEID) (spiffeid.ID, error) { - id, err := idFromProto(ctx, protoID) + id, err := IDFromProto(ctx, protoID) if err != nil { return spiffeid.ID{}, err } @@ -103,7 +103,8 @@ func RemoveEnsureLeadingSlashLogLimit() { ensureLeadingSlashLogLimiter.SetLimit(rate.Inf) } -func idFromProto(ctx context.Context, protoID *types.SPIFFEID) (spiffeid.ID, error) { +// IDFromProto converts a SPIFFEID message into an ID type +func IDFromProto(ctx context.Context, protoID *types.SPIFFEID) (spiffeid.ID, error) { if protoID == nil { return spiffeid.ID{}, errors.New("request must specify SPIFFE ID") }