diff --git a/auth/auth.go b/auth/auth.go index 7a416d52..5d549380 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -20,10 +20,7 @@ import ( // secret keys, which are base64 encoded and separate by whitespace and/or commas. // The first key is used for signing Authorizations, and any key may verify // a presented Authorization. -func NewKeyedAuth(encodedKeys string) (interface { - pb.Authorizer - pb.Verifier -}, error) { +func NewKeyedAuth(encodedKeys string) (*KeyedAuth, error) { var keys jwt.VerificationKeySet for i, key := range strings.Fields(strings.ReplaceAll(encodedKeys, ",", " ")) { @@ -33,28 +30,30 @@ func NewKeyedAuth(encodedKeys string) (interface { keys.Keys = append(keys.Keys, b) } } - return &keySet{keys}, nil + return &KeyedAuth{keys}, nil } -type keySet struct { +// KeyedAuth implements the pb.Authorizer and pb.Verifier +// interfaces using symmetric, pre-shared keys. +type KeyedAuth struct { jwt.VerificationKeySet } -func (k *keySet) Authorize(ctx context.Context, claims pb.Claims, exp time.Duration) (context.Context, error) { +func (k *KeyedAuth) Authorize(ctx context.Context, claims pb.Claims, exp time.Duration) (context.Context, error) { var now = time.Now() claims.IssuedAt = &jwt.NumericDate{Time: now} claims.ExpiresAt = &jwt.NumericDate{Time: now.Add(exp)} var token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(k.Keys[0]) if err != nil { - return ctx, err + return nil, err } return metadata.AppendToOutgoingContext(ctx, "authorization", fmt.Sprintf("Bearer %s", token)), nil } -func (k *keySet) Verify(ctx context.Context, require pb.Capability) (context.Context, context.CancelFunc, pb.Claims, error) { +func (k *KeyedAuth) Verify(ctx context.Context, require pb.Capability) (context.Context, context.CancelFunc, pb.Claims, error) { if claims, err := verifyWithKeys(ctx, require, k.VerificationKeySet); err != nil { - return ctx, func() {}, claims, status.Error(codes.Unauthenticated, err.Error()) + return nil, func() {}, claims, status.Error(codes.Unauthenticated, err.Error()) } else { ctx, cancel := context.WithDeadline(ctx, claims.ExpiresAt.Time) return ctx, cancel, claims, nil diff --git a/broker/protocol/auth.go b/broker/protocol/auth.go index 2876c5e3..192b1c2e 100644 --- a/broker/protocol/auth.go +++ b/broker/protocol/auth.go @@ -57,18 +57,18 @@ type Verifier interface { Verify(ctx context.Context, require Capability) (context.Context, context.CancelFunc, Claims, error) } -// NewAuthJournalClient returns a JournalClient which uses the Authorizer +// NewAuthJournalClient returns an *AuthJournalClient which uses the Authorizer // to obtain and attach an Authorization bearer token to every issued request. -func NewAuthJournalClient(jc JournalClient, auth Authorizer) JournalClient { - return &authClient{auth: auth, jc: jc} +func NewAuthJournalClient(jc JournalClient, auth Authorizer) *AuthJournalClient { + return &AuthJournalClient{Authorizer: auth, JournalClient: jc} } -type authClient struct { - auth Authorizer - jc JournalClient +type AuthJournalClient struct { + Authorizer + JournalClient } -func (a *authClient) List(ctx context.Context, in *ListRequest, opts ...grpc.CallOption) (Journal_ListClient, error) { +func (a *AuthJournalClient) List(ctx context.Context, in *ListRequest, opts ...grpc.CallOption) (Journal_ListClient, error) { var claims, ok = GetClaims(ctx) if !ok { claims = Claims{ @@ -76,26 +76,26 @@ func (a *authClient) List(ctx context.Context, in *ListRequest, opts ...grpc.Cal Selector: in.Selector, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(in.Watch)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(in.Watch)); err != nil { return nil, err } else { - return a.jc.List(ctx, in, opts...) + return a.JournalClient.List(ctx, in, opts...) } } -func (a *authClient) Apply(ctx context.Context, in *ApplyRequest, opts ...grpc.CallOption) (*ApplyResponse, error) { +func (a *AuthJournalClient) Apply(ctx context.Context, in *ApplyRequest, opts ...grpc.CallOption) (*ApplyResponse, error) { var claims, ok = GetClaims(ctx) if !ok { claims = Claims{Capability: Capability_APPLY} } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.jc.Apply(ctx, in, opts...) + return a.JournalClient.Apply(ctx, in, opts...) } } -func (a *authClient) Read(ctx context.Context, in *ReadRequest, opts ...grpc.CallOption) (Journal_ReadClient, error) { +func (a *AuthJournalClient) Read(ctx context.Context, in *ReadRequest, opts ...grpc.CallOption) (Journal_ReadClient, error) { var claims, ok = GetClaims(ctx) if !ok { claims = Claims{ @@ -105,38 +105,38 @@ func (a *authClient) Read(ctx context.Context, in *ReadRequest, opts ...grpc.Cal }, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(true)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(true)); err != nil { return nil, err } else { - return a.jc.Read(ctx, in, opts...) + return a.JournalClient.Read(ctx, in, opts...) } } -func (a *authClient) Append(ctx context.Context, opts ...grpc.CallOption) (Journal_AppendClient, error) { +func (a *AuthJournalClient) Append(ctx context.Context, opts ...grpc.CallOption) (Journal_AppendClient, error) { var claims, ok = GetClaims(ctx) if !ok { panic("Append requires a context having WithClaims") } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(true)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(true)); err != nil { return nil, err } else { - return a.jc.Append(ctx, opts...) + return a.JournalClient.Append(ctx, opts...) } } -func (a *authClient) Replicate(ctx context.Context, opts ...grpc.CallOption) (Journal_ReplicateClient, error) { +func (a *AuthJournalClient) Replicate(ctx context.Context, opts ...grpc.CallOption) (Journal_ReplicateClient, error) { var claims, ok = GetClaims(ctx) if !ok { panic("Replicate requires a context having WithClaims") } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(true)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(true)); err != nil { return nil, err } else { - return a.jc.Replicate(ctx, opts...) + return a.JournalClient.Replicate(ctx, opts...) } } -func (a *authClient) ListFragments(ctx context.Context, in *FragmentsRequest, opts ...grpc.CallOption) (*FragmentsResponse, error) { +func (a *AuthJournalClient) ListFragments(ctx context.Context, in *FragmentsRequest, opts ...grpc.CallOption) (*FragmentsResponse, error) { var claims, ok = GetClaims(ctx) if !ok { claims = Claims{ @@ -146,10 +146,10 @@ func (a *authClient) ListFragments(ctx context.Context, in *FragmentsRequest, op }, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.jc.ListFragments(ctx, in, opts...) + return a.JournalClient.ListFragments(ctx, in, opts...) } } @@ -243,7 +243,7 @@ func (s *authServer) ListFragments(ctx context.Context, req *FragmentsRequest) ( } var _ JournalServer = &authServer{} -var _ JournalClient = &authClient{} +var _ JournalClient = &AuthJournalClient{} type claimsCtxKey struct{} diff --git a/broker/protocol/dispatcher.go b/broker/protocol/dispatcher.go index 79261733..4cf1b63c 100644 --- a/broker/protocol/dispatcher.go +++ b/broker/protocol/dispatcher.go @@ -59,6 +59,15 @@ func WithDispatchItemRoute(ctx context.Context, dr DispatchRouter, item string, dispatchRoute{route: rt, id: id, item: item, DispatchRouter: dr}) } +// GetDispatchRoute returns a Route and ProcessSpec_ID which haven been previously attached to the Context. +func GetDispatchRoute(ctx context.Context) (Route, ProcessSpec_ID, bool) { + if dr, ok := ctx.Value(dispatchRouteCtxKey{}).(dispatchRoute); ok { + return dr.route, dr.id, true + } else { + return Route{}, ProcessSpec_ID{}, false + } +} + // DispatchRouter routes item to Routes, and observes item Routes. type DispatchRouter interface { // Route an |item| to a Route, which may be empty if the Route is unknown. diff --git a/broker/protocol/rpc_extensions.go b/broker/protocol/rpc_extensions.go index dca30eba..255d3e45 100644 --- a/broker/protocol/rpc_extensions.go +++ b/broker/protocol/rpc_extensions.go @@ -16,12 +16,16 @@ type RoutedJournalClient interface { DispatchRouter } +// ComposedRoutedJournalClient implements the RoutedJournalClient interface +// by composing separate implementations of its constituent interfaces. +type ComposedRoutedJournalClient struct { + JournalClient + DispatchRouter +} + // NewRoutedJournalClient composes a JournalClient and DispatchRouter. -func NewRoutedJournalClient(jc JournalClient, dr DispatchRouter) RoutedJournalClient { - return struct { - JournalClient - DispatchRouter - }{ +func NewRoutedJournalClient(jc JournalClient, dr DispatchRouter) *ComposedRoutedJournalClient { + return &ComposedRoutedJournalClient{ JournalClient: jc, DispatchRouter: dr, } diff --git a/consumer/protocol/auth.go b/consumer/protocol/auth.go index dca3aa61..39cef402 100644 --- a/consumer/protocol/auth.go +++ b/consumer/protocol/auth.go @@ -8,18 +8,18 @@ import ( grpc "google.golang.org/grpc" ) -// NewAuthShardClient returns a ShardClient which uses the Authorizer +// NewAuthShardClient returns an *AuthShardClient which uses the Authorizer // to obtain and attach an Authorization bearer token to every issued request. -func NewAuthShardClient(sc ShardClient, auth pb.Authorizer) ShardClient { - return &authShardClient{auth: auth, sc: sc} +func NewAuthShardClient(sc ShardClient, auth pb.Authorizer) *AuthShardClient { + return &AuthShardClient{Authorizer: auth, ShardClient: sc} } -type authShardClient struct { - auth pb.Authorizer - sc ShardClient +type AuthShardClient struct { + pb.Authorizer + ShardClient } -func (a *authShardClient) Stat(ctx context.Context, in *StatRequest, opts ...grpc.CallOption) (*StatResponse, error) { +func (a *AuthShardClient) Stat(ctx context.Context, in *StatRequest, opts ...grpc.CallOption) (*StatResponse, error) { var claims, ok = pb.GetClaims(ctx) if !ok { claims = pb.Claims{ @@ -29,14 +29,14 @@ func (a *authShardClient) Stat(ctx context.Context, in *StatRequest, opts ...grp }, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.sc.Stat(ctx, in, opts...) + return a.ShardClient.Stat(ctx, in, opts...) } } -func (a *authShardClient) List(ctx context.Context, in *ListRequest, opts ...grpc.CallOption) (*ListResponse, error) { +func (a *AuthShardClient) List(ctx context.Context, in *ListRequest, opts ...grpc.CallOption) (*ListResponse, error) { var claims, ok = pb.GetClaims(ctx) if !ok { claims = pb.Claims{ @@ -44,26 +44,26 @@ func (a *authShardClient) List(ctx context.Context, in *ListRequest, opts ...grp Selector: in.Selector, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.sc.List(ctx, in, opts...) + return a.ShardClient.List(ctx, in, opts...) } } -func (a *authShardClient) Apply(ctx context.Context, in *ApplyRequest, opts ...grpc.CallOption) (*ApplyResponse, error) { +func (a *AuthShardClient) Apply(ctx context.Context, in *ApplyRequest, opts ...grpc.CallOption) (*ApplyResponse, error) { var claims, ok = pb.GetClaims(ctx) if !ok { claims = pb.Claims{Capability: pb.Capability_APPLY} } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.sc.Apply(ctx, in, opts...) + return a.ShardClient.Apply(ctx, in, opts...) } } -func (a *authShardClient) GetHints(ctx context.Context, in *GetHintsRequest, opts ...grpc.CallOption) (*GetHintsResponse, error) { +func (a *AuthShardClient) GetHints(ctx context.Context, in *GetHintsRequest, opts ...grpc.CallOption) (*GetHintsResponse, error) { var claims, ok = pb.GetClaims(ctx) if !ok { claims = pb.Claims{ @@ -73,14 +73,14 @@ func (a *authShardClient) GetHints(ctx context.Context, in *GetHintsRequest, opt }, } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.sc.GetHints(ctx, in, opts...) + return a.ShardClient.GetHints(ctx, in, opts...) } } -func (a *authShardClient) Unassign(ctx context.Context, in *UnassignRequest, opts ...grpc.CallOption) (*UnassignResponse, error) { +func (a *AuthShardClient) Unassign(ctx context.Context, in *UnassignRequest, opts ...grpc.CallOption) (*UnassignResponse, error) { var claims, ok = pb.GetClaims(ctx) if !ok { claims = pb.Claims{Capability: pb.Capability_APPLY} @@ -88,10 +88,10 @@ func (a *authShardClient) Unassign(ctx context.Context, in *UnassignRequest, opt claims.Selector.Include.AddValue("id", id.String()) } } - if ctx, err := a.auth.Authorize(ctx, claims, withExp(false)); err != nil { + if ctx, err := a.Authorizer.Authorize(ctx, claims, withExp(false)); err != nil { return nil, err } else { - return a.sc.Unassign(ctx, in, opts...) + return a.ShardClient.Unassign(ctx, in, opts...) } } @@ -175,4 +175,4 @@ func (s *authServer) Unassign(ctx context.Context, req *UnassignRequest) (*Unass } var _ ShardServer = &authServer{} -var _ ShardClient = &authShardClient{} +var _ ShardClient = &AuthShardClient{}