Skip to content

Commit

Permalink
refactor: add required field checks, fix: missing method_name and res…
Browse files Browse the repository at this point in the history
…ource_name check (#430)

* refactor: add required field checks, fix: missing method_name check

* fix integ test failure

* address comments

* add comments

* address feedback

* fix lint

* fix nit
  • Loading branch information
sqin2019 authored Aug 3, 2023
1 parent d7994e6 commit 4035c08
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 115 deletions.
7 changes: 5 additions & 2 deletions clients/go/pkg/audit/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1070,15 +1070,18 @@ func TestStreamInterceptor(t *testing.T) {
},
},
handler: func(srv interface{}, ss grpc.ServerStream) error {
logReq, _ := LogReqFromCtx(ss.Context())
logReq.Payload.ResourceName = "ExampleResourceName"
return grpcstatus.Error(codes.Internal, "something is wrong")
},
wantErrSubstr: "something is wrong",
wantLogReqs: []*api.AuditLogRequest{
{
Type: api.AuditLogRequest_DATA_ACCESS,
Payload: &capi.AuditLog{
ServiceName: "ExampleService",
MethodName: "/ExampleService/ExampleMethod",
ServiceName: "ExampleService",
MethodName: "/ExampleService/ExampleMethod",
ResourceName: "ExampleResourceName",
AuthenticationInfo: &capi.AuthenticationInfo{
PrincipalEmail: "user@example.com",
},
Expand Down
34 changes: 3 additions & 31 deletions clients/go/pkg/audit/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ package audit
import (
"context"
"fmt"
"strings"

api "github.com/abcxyz/lumberjack/clients/go/apis/v1alpha1"
"github.com/abcxyz/lumberjack/clients/go/pkg/auditerrors"
"github.com/abcxyz/lumberjack/pkg/validation"
)

// RequestValidator validates log request fields.
Expand All @@ -46,36 +46,8 @@ func (p *RequestValidator) process(ctx context.Context, logReq *api.AuditLogRequ
return fmt.Errorf("AuditLogRequest cannot be nil")
}

if logReq.Payload == nil {
return fmt.Errorf("AuditLogRequest.Payload cannot be nil")
}

if logReq.Payload.ServiceName == "" {
return fmt.Errorf("ServiceName cannot be empty")
}

if logReq.Payload.AuthenticationInfo == nil {
return fmt.Errorf("AuthenticationInfo cannot be nil")
}

email := logReq.Payload.AuthenticationInfo.PrincipalEmail
if err := p.validateEmail(email); err != nil {
return err
}

return nil
}

// This method is intended to validate that the email associated with the
// authentication request has the correct format and in a valid domain.
func (p *RequestValidator) validateEmail(email string) error {
if email == "" {
return fmt.Errorf("PrincipalEmail cannot be empty")
}

parts := strings.Split(email, "@")
if len(parts) != 2 || parts[1] == "" {
return fmt.Errorf("PrincipalEmail %q is malformed", email)
if err := validation.ValidateAuditLog(logReq.Payload); err != nil {
return fmt.Errorf("AuditLogRequest does not have a valid payload: %w", err)
}
return nil
}
71 changes: 0 additions & 71 deletions clients/go/pkg/audit/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"google.golang.org/genproto/googleapis/cloud/audit"
"google.golang.org/protobuf/testing/protocmp"

api "github.com/abcxyz/lumberjack/clients/go/apis/v1alpha1"
Expand Down Expand Up @@ -53,76 +52,6 @@ func TestRequestValidation_Process(t *testing.T) {
name: "should_error_when_logReq_is_nil",
wantErr: auditerrors.ErrInvalidRequest,
},
{
name: "should_error_when_authInfo_is_nil",
logReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
},
},
wantLogReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
},
},
wantErr: auditerrors.ErrInvalidRequest,
},
{
name: "should_error_when_auth_email_is_nil",
logReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
AuthenticationInfo: &audit.AuthenticationInfo{},
},
},
wantLogReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
AuthenticationInfo: &audit.AuthenticationInfo{},
},
},
wantErr: auditerrors.ErrInvalidRequest,
},
{
name: "should_error_when_auth_email_has_no_domain",
logReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
AuthenticationInfo: &audit.AuthenticationInfo{
PrincipalEmail: "user",
},
},
},
wantLogReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "test-service",
AuthenticationInfo: &audit.AuthenticationInfo{
PrincipalEmail: "user",
},
},
},
wantErr: auditerrors.ErrInvalidRequest,
},
{
name: "should_error_when_serviceName_is_empty",
logReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "",
AuthenticationInfo: &audit.AuthenticationInfo{
PrincipalEmail: "user@test.com",
},
},
},
wantLogReq: &api.AuditLogRequest{
Payload: &audit.AuditLog{
ServiceName: "",
AuthenticationInfo: &audit.AuthenticationInfo{
PrincipalEmail: "user@test.com",
},
},
},
wantErr: auditerrors.ErrInvalidRequest,
},
}

for _, tc := range tests {
Expand Down
1 change: 1 addition & 0 deletions clients/go/pkg/testutil/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func NewRequest(opts ...RequestOptions) *api.AuditLogRequest {
request := &api.AuditLogRequest{
Type: api.AuditLogRequest_DATA_ACCESS,
Payload: &audit.AuditLog{
MethodName: "test-method",
ServiceName: "test-service",
ResourceName: "test-resource",
AuthenticationInfo: &audit.AuthenticationInfo{
Expand Down
3 changes: 2 additions & 1 deletion clients/go/test/shell/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
auditLogRequest := &v1alpha1.AuditLogRequest{
Type: v1alpha1.AuditLogRequest_DATA_ACCESS,
Payload: &cal.AuditLog{
ServiceName: serviceName,
ServiceName: serviceName,
ResourceName: traceID,
AuthenticationInfo: &cal.AuthenticationInfo{
PrincipalEmail: email,
},
Expand Down
86 changes: 79 additions & 7 deletions pkg/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,58 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Package validation provides untils for lumberjack/data access logs
// validation.
// Package validation provides utils for lumberjack/data access logs validation.
package validation

import (
"errors"
"fmt"
"strings"

"google.golang.org/protobuf/encoding/protojson"

lepb "cloud.google.com/go/logging/apiv2/loggingpb"
cal "google.golang.org/genproto/googleapis/cloud/audit"
)

var requiredLabels = map[string]struct{}{
"environment": {},
"accessing_process_name": {},
}

// Validator validates a lumberjack log entry.
type Validator func(le *lepb.LogEntry) error

// Validate validates a json string representation of a lumberjack log.
func Validate(log string) error {
func Validate(log string, extra ...Validator) error {
var logEntry lepb.LogEntry
if err := protojson.Unmarshal([]byte(log), &logEntry); err != nil {
return fmt.Errorf("failed to parse log entry as JSON: %w", err)
}

if err := validatePayload(&logEntry); err != nil {
return fmt.Errorf("failed to validate payload: %w", err)
var retErr error
for _, v := range append([]Validator{validatePayload}, extra...) {
retErr = errors.Join(retErr, v(&logEntry))
}
return retErr
}

// TODO (#427): add required fields check.
return nil
// ValidateLabels checks required lumberjack labels.
func ValidateLabels(le *lepb.LogEntry) error {
if le.Labels == nil {
return fmt.Errorf("missing labels")
}

var retErr error
for k := range requiredLabels {
if _, ok := le.Labels[k]; !ok {
retErr = errors.Join(retErr, fmt.Errorf("missing required label: %q", k))
}
}
return retErr
}

// Required audit log payload check for lumberjack logs.
func validatePayload(logEntry *lepb.LogEntry) error {
payload := logEntry.GetJsonPayload()
if payload == nil {
Expand All @@ -54,5 +78,53 @@ func validatePayload(logEntry *lepb.LogEntry) error {
if err := protojson.Unmarshal(val, &al); err != nil {
return fmt.Errorf("failed to parse JSON payload: %w", err)
}
if err := ValidateAuditLog(&al); err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
return nil
}

// ValidateAuditLog validates the audit log payload for lumberjack.
func ValidateAuditLog(payload *cal.AuditLog) error {
if payload == nil {
return fmt.Errorf("audit log payload cannot be nil")
}

var retErr error
if payload.MethodName == "" {
retErr = errors.Join(retErr, fmt.Errorf("MethodName cannot be empty"))
}

if payload.ServiceName == "" {
retErr = errors.Join(retErr, fmt.Errorf("ServiceName cannot be empty"))
}

if payload.ResourceName == "" {
retErr = errors.Join(retErr, fmt.Errorf("ResourceName cannot be empty"))
}

if payload.AuthenticationInfo == nil {
retErr = errors.Join(retErr, fmt.Errorf("AuthenticationInfo cannot be nil"))
} else {
email := payload.AuthenticationInfo.PrincipalEmail
if err := validateEmail(email); err != nil {
retErr = errors.Join(retErr, err)
}
}

return retErr
}

// This method is intended to validate that the email associated with the
// authentication request has the correct format and in a valid domain.
func validateEmail(email string) error {
if email == "" {
return fmt.Errorf("PrincipalEmail cannot be empty")
}

parts := strings.Split(email, "@")
if len(parts) != 2 || parts[1] == "" {
return fmt.Errorf("PrincipalEmail %q is malformed", email)
}
return nil
}
Loading

0 comments on commit 4035c08

Please sign in to comment.