diff --git a/pkg/client/zitadel/client.go b/pkg/client/zitadel/client.go index 3fce3d5..080e65d 100644 --- a/pkg/client/zitadel/client.go +++ b/pkg/client/zitadel/client.go @@ -20,6 +20,8 @@ type Connection struct { scopes []string orgID string insecure bool + unaryInterceptors []grpc.UnaryClientInterceptor + streamInterceptors []grpc.StreamClientInterceptor *grpc.ClientConn } @@ -37,18 +39,19 @@ func NewConnection(scopes []string, options ...Option) (*Connection, error) { } } - unaryInterceptors, streamInterceptors, err := interceptors(c.issuer, c.orgID, c.scopes, c.jwtProfileTokenSource) + err := c.setInterceptors(c.issuer, c.orgID, c.scopes, c.jwtProfileTokenSource) if err != nil { return nil, err } dialOptions := []grpc.DialOption{ grpc.WithChainUnaryInterceptor( - unaryInterceptors..., + c.unaryInterceptors..., ), grpc.WithChainStreamInterceptor( - streamInterceptors..., + c.streamInterceptors..., ), } + opt, err := transportOption(c.api, c.insecure) if err != nil { return nil, err @@ -63,19 +66,20 @@ func NewConnection(scopes []string, options ...Option) (*Connection, error) { return c, nil } -func interceptors(issuer, orgID string, scopes []string, jwtProfileTokenSource middleware.JWTProfileTokenSource) ([]grpc.UnaryClientInterceptor, []grpc.StreamClientInterceptor, error) { +func (c *Connection) setInterceptors(issuer, orgID string, scopes []string, jwtProfileTokenSource middleware.JWTProfileTokenSource) error { auth, err := middleware.NewAuthenticator(issuer, jwtProfileTokenSource, scopes...) if err != nil { - return nil, nil, err + return err } - unaryInterceptors := []grpc.UnaryClientInterceptor{auth.Unary()} - streamInterceptors := []grpc.StreamClientInterceptor{auth.Stream()} + + c.unaryInterceptors = append(c.unaryInterceptors, auth.Unary()) + c.streamInterceptors = append(c.streamInterceptors, auth.Stream()) if orgID != "" { org := middleware.NewOrgInterceptor(orgID) - unaryInterceptors = append(unaryInterceptors, org.Unary()) - streamInterceptors = append(streamInterceptors, org.Stream()) + c.unaryInterceptors = append(c.unaryInterceptors, org.Unary()) + c.streamInterceptors = append(c.streamInterceptors, org.Stream()) } - return unaryInterceptors, streamInterceptors, nil + return nil } func transportOption(api string, insecure bool) (grpc.DialOption, error) { @@ -152,3 +156,19 @@ func WithInsecure() func(*Connection) error { return nil } } + +//WithUnaryInterceptors adds non ZITADEL specific interceptors to the connection +func WithUnaryInterceptors(interceptors ...grpc.UnaryClientInterceptor) func(*Connection) error { + return func(client *Connection) error { + client.unaryInterceptors = append(client.unaryInterceptors, interceptors...) + return nil + } +} + +//WithStreamInterceptors adds non ZITADEL specific interceptors to the connection +func WithStreamInterceptors(interceptors ...grpc.StreamClientInterceptor) func(*Connection) error { + return func(client *Connection) error { + client.streamInterceptors = append(client.streamInterceptors, interceptors...) + return nil + } +}