Skip to content

Commit

Permalink
Merge pull request #252 from xmidt-org/feature/cert-info-in-handler
Browse files Browse the repository at this point in the history
Feature/cert info in handler
  • Loading branch information
schmidtw authored Nov 15, 2024
2 parents 4deb160 + 50633d1 commit af5e10f
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 4 deletions.
33 changes: 33 additions & 0 deletions token/claimBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ import (
kithttp "github.com/go-kit/kit/transport/http"
)

const (
// ClaimTrust is the name of the trust value within JWT claims issued
// by themis. This claim will be overridden based upon TLS connection state.
ClaimTrust = "trust"
)

var (
ErrRemoteURLRequired = errors.New("A URL for the remote claimer is required")
ErrMissingKey = errors.New("A key is required for all claims and metadata values")
Expand Down Expand Up @@ -172,6 +178,16 @@ func newRemoteClaimBuilder(client xhttpclient.Interface, metadata map[string]int
return &remoteClaimBuilder{endpoint: c.Endpoint(), url: r.URL, extra: metadata}, nil
}

// enforcePeerCertificate is a ClaimsBuilderFunc that overrides trust as necessary
// given the TLS peer certificates (if any)
func enforcePeerCertificate(_ context.Context, r *Request, target map[string]interface{}) error {
if len(r.ConnectionState.PeerCertificates) == 0 {
target[ClaimTrust] = 0
}

return nil
}

// NewClaimBuilders constructs a ClaimBuilders from configuration. The returned instance is typically
// used in configuration a token Factory. It can be used as a standalone service component with an endpoint.
//
Expand All @@ -182,6 +198,7 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options)
builders = ClaimBuilders{requestClaimBuilder{}}
staticClaimBuilder = make(staticClaimBuilder)
)

if o.Remote != nil { // scan the metadata looking for static values that should be applied when invoking the remote server
metadata := make(map[string]interface{})
for _, value := range o.Metadata {
Expand All @@ -200,34 +217,44 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options)
metadata[value.Key] = msg
}
}

remoteClaimBuilder, err := newRemoteClaimBuilder(client, metadata, o.Remote)
if err != nil {
return nil, err
}

builders = append(builders, remoteClaimBuilder)
}

for _, value := range o.Claims {
switch {
case len(value.Key) == 0:
return nil, ErrMissingKey

case value.IsFromHTTP():
continue

case !value.IsStatic():
return nil, fmt.Errorf("A value is required for the static claim: %s", value.Key)

default:
msg, err := value.RawMessage()
if err != nil {
return nil, err
}

staticClaimBuilder[value.Key] = msg
}
}

if len(staticClaimBuilder) > 0 {
builders = append(builders, staticClaimBuilder)
}

if o.Nonce && n != nil {
builders = append(builders, nonceClaimBuilder{n: n})
}

if !o.DisableTime {
builders = append(
builders,
Expand All @@ -238,5 +265,11 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options)
notBeforeDelta: o.NotBeforeDelta,
})
}

builders = append(
builders,
ClaimBuilderFunc(enforcePeerCertificate),
)

return builders, nil
}
5 changes: 4 additions & 1 deletion token/claimBuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ func (suite *NewClaimBuildersTestSuite) TestMinimum() {
)

suite.Equal(
map[string]interface{}{"request": 123},
map[string]interface{}{"request": 123, "trust": 0},
actual,
)
}
Expand Down Expand Up @@ -691,6 +691,7 @@ func (suite *NewClaimBuildersTestSuite) TestStatic() {
"static1": suite.rawMessage(-72.5),
"static2": suite.rawMessage([]string{"a", "b"}),
"request": 123,
"trust": 0,
},
actual,
)
Expand Down Expand Up @@ -737,6 +738,7 @@ func (suite *NewClaimBuildersTestSuite) TestNoRemote() {
"iat": suite.expectedNow.UTC().Unix(),
"nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(),
"exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(),
"trust": 0,
},
actual,
)
Expand Down Expand Up @@ -821,6 +823,7 @@ func (suite *NewClaimBuildersTestSuite) TestFull() {
"iat": suite.expectedNow.UTC().Unix(),
"nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(),
"exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(),
"trust": 0,
},
actual,
)
Expand Down
5 changes: 5 additions & 0 deletions token/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package token

import (
"context"
"crypto/tls"
"fmt"
"sync/atomic"

Expand All @@ -26,6 +27,10 @@ type Request struct {
// Metadata holds non-claim information about the request, usually garnered from the original HTTP request. This
// metadata is available to lower levels of infrastructure used by the Factory.
Metadata map[string]interface{}

// ConnectionState represents the state of any underlying TLS connection.
// For non-tls connections, this field is unset.
ConnectionState tls.ConnectionState
}

// NewRequest returns an empty, fully initialized token Request
Expand Down
17 changes: 17 additions & 0 deletions token/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ func (prb partnerIDRequestBuilder) Build(original *http.Request, tr *Request) er
return nil
}

// setConnectionState sets the tls.ConnectionState for the given request.
func setConnectionState(original *http.Request, tr *Request) error {
if cs, ok := xhttpserver.ConnectionState(original.Context()); ok {
tr.ConnectionState = cs
}

return nil
}

// NewRequestBuilders creates a RequestBuilders sequence given an Options configuration. Only claims
// and metadata that are HTTP-based are included in the results. Claims and metadata that are statically
// assigned values are handled by ClaimBuilder objects and are part of the Factory configuration.
Expand Down Expand Up @@ -238,6 +247,7 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) {
)
}
}

for _, value := range o.Metadata {
switch {
case len(value.Key) == 0:
Expand All @@ -264,13 +274,20 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) {
)
}
}

if o.PartnerID != nil && (len(o.PartnerID.Claim) > 0 || len(o.PartnerID.Metadata) > 0) {
rb = append(rb,
partnerIDRequestBuilder{
PartnerID: *o.PartnerID,
},
)
}

rb = append(
rb,
RequestBuilderFunc(setConnectionState),
)

return rb, nil
}

Expand Down
3 changes: 0 additions & 3 deletions token/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func testUnmarshalError(t *testing.T) {

app := fx.New(
fx.Provide(
fx.Logger(sallust.Printer{}),
config.ProvideViper(
config.Json(`
{
Expand All @@ -49,7 +48,6 @@ func testUnmarshalClaimBuilderError(t *testing.T) {

app = fx.New(
fx.Provide(
fx.Logger(sallust.Printer{}),
config.ProvideViper(
config.Json(`
{
Expand Down Expand Up @@ -84,7 +82,6 @@ func testUnmarshalFactoryError(t *testing.T) {

app = fx.New(
fx.Provide(
fx.Logger(sallust.Printer{}),
config.ProvideViper(
config.Json(`
{
Expand Down
13 changes: 13 additions & 0 deletions xhttp/xhttpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package xhttpserver

import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
Expand Down Expand Up @@ -98,6 +99,18 @@ func New(o Options, l *zap.Logger, h http.Handler) Interface {
o.Address,
l,
),

ConnContext: func(ctx context.Context, c net.Conn) context.Context {
type connectionStater interface {
ConnectionState() tls.ConnectionState
}

if cs, ok := c.(connectionStater); ok {
ctx = SetConnectionState(ctx, cs.ConnectionState())
}

return ctx
},
}

if o.LogConnectionState {
Expand Down
1 change: 1 addition & 0 deletions xhttp/xhttpserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ func testNewSimple(t *testing.T) {
assert.Greater(output.Len(), 0)

assert.Nil(s.(*http.Server).ConnState)
assert.NotNil(s.(*http.Server).ConnContext)
}

func testNewFull(t *testing.T) {
Expand Down
18 changes: 18 additions & 0 deletions xhttp/xhttpserver/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package xhttpserver

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand Down Expand Up @@ -249,3 +250,20 @@ func NewTlsConfig(t *Tls, extra ...PeerVerifier) (*tls.Config, error) {
tc.BuildNameToCertificate() // nolint: staticcheck
return tc, nil
}

type connectionStateKey struct{}

// ConnectionState returns the tls.ConnectionState from the given context.
func ConnectionState(ctx context.Context) (cs tls.ConnectionState, present bool) {
cs, present = ctx.Value(connectionStateKey{}).(tls.ConnectionState)
return
}

// SetConnectionState associates a tls.ConnectionState with the given context.
func SetConnectionState(ctx context.Context, cs tls.ConnectionState) context.Context {
return context.WithValue(
ctx,
connectionStateKey{},
cs,
)
}

0 comments on commit af5e10f

Please sign in to comment.