diff --git a/token/claimBuilder.go b/token/claimBuilder.go index e7beb58..d6b9a04 100644 --- a/token/claimBuilder.go +++ b/token/claimBuilder.go @@ -17,6 +17,12 @@ import ( kithttp "github.com/go-kit/kit/transport/http" ) +const ( + // ClaimTrust is the name of the trust value within JWT claims issued + // by themis. This claim will be overridden based upon TLS connection state. + ClaimTrust = "trust" +) + var ( ErrRemoteURLRequired = errors.New("A URL for the remote claimer is required") ErrMissingKey = errors.New("A key is required for all claims and metadata values") @@ -172,6 +178,16 @@ func newRemoteClaimBuilder(client xhttpclient.Interface, metadata map[string]int return &remoteClaimBuilder{endpoint: c.Endpoint(), url: r.URL, extra: metadata}, nil } +// enforcePeerCertificate is a ClaimsBuilderFunc that overrides trust as necessary +// given the TLS peer certificates (if any) +func enforcePeerCertificate(_ context.Context, r *Request, target map[string]interface{}) error { + if len(r.ConnectionState.PeerCertificates) == 0 { + target[ClaimTrust] = 0 + } + + return nil +} + // NewClaimBuilders constructs a ClaimBuilders from configuration. The returned instance is typically // used in configuration a token Factory. It can be used as a standalone service component with an endpoint. // @@ -182,6 +198,7 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options) builders = ClaimBuilders{requestClaimBuilder{}} staticClaimBuilder = make(staticClaimBuilder) ) + if o.Remote != nil { // scan the metadata looking for static values that should be applied when invoking the remote server metadata := make(map[string]interface{}) for _, value := range o.Metadata { @@ -200,34 +217,44 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options) metadata[value.Key] = msg } } + remoteClaimBuilder, err := newRemoteClaimBuilder(client, metadata, o.Remote) if err != nil { return nil, err } + builders = append(builders, remoteClaimBuilder) } + for _, value := range o.Claims { switch { case len(value.Key) == 0: return nil, ErrMissingKey + case value.IsFromHTTP(): continue + case !value.IsStatic(): return nil, fmt.Errorf("A value is required for the static claim: %s", value.Key) + default: msg, err := value.RawMessage() if err != nil { return nil, err } + staticClaimBuilder[value.Key] = msg } } + if len(staticClaimBuilder) > 0 { builders = append(builders, staticClaimBuilder) } + if o.Nonce && n != nil { builders = append(builders, nonceClaimBuilder{n: n}) } + if !o.DisableTime { builders = append( builders, @@ -238,5 +265,11 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options) notBeforeDelta: o.NotBeforeDelta, }) } + + builders = append( + builders, + ClaimBuilderFunc(enforcePeerCertificate), + ) + return builders, nil } diff --git a/token/claimBuilder_test.go b/token/claimBuilder_test.go index eb363b8..93a4fa1 100644 --- a/token/claimBuilder_test.go +++ b/token/claimBuilder_test.go @@ -545,7 +545,7 @@ func (suite *NewClaimBuildersTestSuite) TestMinimum() { ) suite.Equal( - map[string]interface{}{"request": 123}, + map[string]interface{}{"request": 123, "trust": 0}, actual, ) } @@ -691,6 +691,7 @@ func (suite *NewClaimBuildersTestSuite) TestStatic() { "static1": suite.rawMessage(-72.5), "static2": suite.rawMessage([]string{"a", "b"}), "request": 123, + "trust": 0, }, actual, ) @@ -737,6 +738,7 @@ func (suite *NewClaimBuildersTestSuite) TestNoRemote() { "iat": suite.expectedNow.UTC().Unix(), "nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(), "exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(), + "trust": 0, }, actual, ) @@ -821,6 +823,7 @@ func (suite *NewClaimBuildersTestSuite) TestFull() { "iat": suite.expectedNow.UTC().Unix(), "nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(), "exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(), + "trust": 0, }, actual, ) diff --git a/token/factory.go b/token/factory.go index 6d7c47d..93785af 100644 --- a/token/factory.go +++ b/token/factory.go @@ -4,6 +4,7 @@ package token import ( "context" + "crypto/tls" "fmt" "sync/atomic" @@ -26,6 +27,10 @@ type Request struct { // Metadata holds non-claim information about the request, usually garnered from the original HTTP request. This // metadata is available to lower levels of infrastructure used by the Factory. Metadata map[string]interface{} + + // ConnectionState represents the state of any underlying TLS connection. + // For non-tls connections, this field is unset. + ConnectionState tls.ConnectionState } // NewRequest returns an empty, fully initialized token Request diff --git a/token/transport.go b/token/transport.go index 7e0a944..7c8a774 100644 --- a/token/transport.go +++ b/token/transport.go @@ -208,6 +208,15 @@ func (prb partnerIDRequestBuilder) Build(original *http.Request, tr *Request) er return nil } +// setConnectionState sets the tls.ConnectionState for the given request. +func setConnectionState(original *http.Request, tr *Request) error { + if cs, ok := xhttpserver.ConnectionState(original.Context()); ok { + tr.ConnectionState = cs + } + + return nil +} + // NewRequestBuilders creates a RequestBuilders sequence given an Options configuration. Only claims // and metadata that are HTTP-based are included in the results. Claims and metadata that are statically // assigned values are handled by ClaimBuilder objects and are part of the Factory configuration. @@ -238,6 +247,7 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) { ) } } + for _, value := range o.Metadata { switch { case len(value.Key) == 0: @@ -264,6 +274,7 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) { ) } } + if o.PartnerID != nil && (len(o.PartnerID.Claim) > 0 || len(o.PartnerID.Metadata) > 0) { rb = append(rb, partnerIDRequestBuilder{ @@ -271,6 +282,12 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) { }, ) } + + rb = append( + rb, + RequestBuilderFunc(setConnectionState), + ) + return rb, nil } diff --git a/token/unmarshal_test.go b/token/unmarshal_test.go index 9a66dc2..aed90bb 100644 --- a/token/unmarshal_test.go +++ b/token/unmarshal_test.go @@ -22,7 +22,6 @@ func testUnmarshalError(t *testing.T) { app := fx.New( fx.Provide( - fx.Logger(sallust.Printer{}), config.ProvideViper( config.Json(` { @@ -49,7 +48,6 @@ func testUnmarshalClaimBuilderError(t *testing.T) { app = fx.New( fx.Provide( - fx.Logger(sallust.Printer{}), config.ProvideViper( config.Json(` { @@ -84,7 +82,6 @@ func testUnmarshalFactoryError(t *testing.T) { app = fx.New( fx.Provide( - fx.Logger(sallust.Printer{}), config.ProvideViper( config.Json(` { diff --git a/xhttp/xhttpserver/server.go b/xhttp/xhttpserver/server.go index ca5a3aa..05fe050 100644 --- a/xhttp/xhttpserver/server.go +++ b/xhttp/xhttpserver/server.go @@ -4,6 +4,7 @@ package xhttpserver import ( "context" + "crypto/tls" "net" "net/http" "time" @@ -98,6 +99,18 @@ func New(o Options, l *zap.Logger, h http.Handler) Interface { o.Address, l, ), + + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + type connectionStater interface { + ConnectionState() tls.ConnectionState + } + + if cs, ok := c.(connectionStater); ok { + ctx = SetConnectionState(ctx, cs.ConnectionState()) + } + + return ctx + }, } if o.LogConnectionState { diff --git a/xhttp/xhttpserver/server_test.go b/xhttp/xhttpserver/server_test.go index f893e72..be89352 100644 --- a/xhttp/xhttpserver/server_test.go +++ b/xhttp/xhttpserver/server_test.go @@ -212,6 +212,7 @@ func testNewSimple(t *testing.T) { assert.Greater(output.Len(), 0) assert.Nil(s.(*http.Server).ConnState) + assert.NotNil(s.(*http.Server).ConnContext) } func testNewFull(t *testing.T) { diff --git a/xhttp/xhttpserver/tls.go b/xhttp/xhttpserver/tls.go index 2caebce..4a8c908 100644 --- a/xhttp/xhttpserver/tls.go +++ b/xhttp/xhttpserver/tls.go @@ -3,6 +3,7 @@ package xhttpserver import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -249,3 +250,20 @@ func NewTlsConfig(t *Tls, extra ...PeerVerifier) (*tls.Config, error) { tc.BuildNameToCertificate() // nolint: staticcheck return tc, nil } + +type connectionStateKey struct{} + +// ConnectionState returns the tls.ConnectionState from the given context. +func ConnectionState(ctx context.Context) (cs tls.ConnectionState, present bool) { + cs, present = ctx.Value(connectionStateKey{}).(tls.ConnectionState) + return +} + +// SetConnectionState associates a tls.ConnectionState with the given context. +func SetConnectionState(ctx context.Context, cs tls.ConnectionState) context.Context { + return context.WithValue( + ctx, + connectionStateKey{}, + cs, + ) +}