diff --git a/acme/api/order.go b/acme/api/order.go index beda4e5c4..14549e75e 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "fmt" "net" "net/http" "strings" @@ -49,30 +50,86 @@ func (n *NewOrderRequest) Validate() error { if id.Value == "" { return acme.NewError(acme.ErrorMalformedType, "permanent identifier cannot be empty") } - case acme.WireUser: - _, err := wire.ParseUserID([]byte(id.Value)) - if err != nil { - return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing Wire ID") - } - case acme.WireDevice: - wireID, err := wire.ParseDeviceID([]byte(id.Value)) - if err != nil { - return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing Wire ID") - } - if _, err := wire.ParseClientID(wireID.ClientID); err != nil { - return acme.WrapError(acme.ErrorMalformedType, err, "invalid Wire client ID %q", wireID.ClientID) - } + case acme.WireUser, acme.WireDevice: + // validation of Wire identifiers is performed in `validateWireIdentifiers`, but + // marked here as known and supported types. + continue default: return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type) } + } - // TODO(hs): add some validations for DNS domains? - // TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 + if err := n.validateWireIdentifiers(); err != nil { + return acme.WrapError(acme.ErrorMalformedType, err, "failed validating Wire identifiers") } + // TODO(hs): add some validations for DNS domains? + // TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1 + return nil } +func (n *NewOrderRequest) validateWireIdentifiers() error { + if !n.hasWireIdentifiers() { + return nil + } + + userIdentifiers := identifiersOfType(acme.WireUser, n.Identifiers) + deviceIdentifiers := identifiersOfType(acme.WireDevice, n.Identifiers) + + if len(userIdentifiers) != 1 { + return fmt.Errorf("expected exactly one Wire UserID identifier; got %d", len(userIdentifiers)) + } + if len(deviceIdentifiers) != 1 { + return fmt.Errorf("expected exactly one Wire DeviceID identifier, got %d", len(deviceIdentifiers)) + } + + wireUserID, err := wire.ParseUserID(userIdentifiers[0].Value) + if err != nil { + return fmt.Errorf("failed parsing Wire UserID: %w", err) + } + + wireDeviceID, err := wire.ParseDeviceID(deviceIdentifiers[0].Value) + if err != nil { + return fmt.Errorf("failed parsing Wire DeviceID: %w", err) + } + if _, err := wire.ParseClientID(wireDeviceID.ClientID); err != nil { + return fmt.Errorf("invalid Wire client ID %q: %w", wireDeviceID.ClientID, err) + } + + switch { + case wireUserID.Domain != wireDeviceID.Domain: + return fmt.Errorf("UserID domain %q does not match DeviceID domain %q", wireUserID.Domain, wireDeviceID.Domain) + case wireUserID.Name != wireDeviceID.Name: + return fmt.Errorf("UserID name %q does not match DeviceID name %q", wireUserID.Name, wireDeviceID.Name) + case wireUserID.Handle != wireDeviceID.Handle: + return fmt.Errorf("UserID handle %q does not match DeviceID handle %q", wireUserID.Handle, wireDeviceID.Handle) + } + + return nil +} + +// hasWireIdentifiers returns whether the [NewOrderRequest] contains +// Wire identifiers. +func (n *NewOrderRequest) hasWireIdentifiers() bool { + for _, i := range n.Identifiers { + if i.Type == acme.WireUser || i.Type == acme.WireDevice { + return true + } + } + return false +} + +// identifiersOfType returns the Identifiers that are of type typ. +func identifiersOfType(typ acme.IdentifierType, ids []acme.Identifier) (result []acme.Identifier) { + for _, id := range ids { + if id.Type == typ { + result = append(result, id) + } + } + return +} + // FinalizeRequest captures the body for a Finalize order request. type FinalizeRequest struct { CSR string `json:"csr"` @@ -284,22 +341,14 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { if err != nil { return acme.WrapErrorISE(err, "failed getting Wire options") } - var targetProvider interface{ EvaluateTarget(string) (string, error) } - switch typ { - case acme.WIREOIDC01: - targetProvider = wireOptions.GetOIDCOptions() - default: - return acme.NewError(acme.ErrorMalformedType, "unsupported type %q", typ) - } - - target, err = targetProvider.EvaluateTarget("") + target, err = wireOptions.GetOIDCOptions().EvaluateTarget("") // TODO(hs): determine if required by Wire if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'") } case acme.WireDevice: - wireID, err := wire.ParseDeviceID([]byte(az.Identifier.Value)) + wireID, err := wire.ParseDeviceID(az.Identifier.Value) if err != nil { - return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireUser") + return acme.WrapError(acme.ErrorMalformedType, err, "failed parsing WireDevice") } clientID, err := wire.ParseClientID(wireID.ClientID) if err != nil { @@ -309,15 +358,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { if err != nil { return acme.WrapErrorISE(err, "failed getting Wire options") } - var targetProvider interface{ EvaluateTarget(string) (string, error) } - switch typ { - case acme.WIREDPOP01: - targetProvider = wireOptions.GetDPOPOptions() - default: - return acme.NewError(acme.ErrorMalformedType, "unsupported type %q", typ) - } - - target, err = targetProvider.EvaluateTarget(clientID.DeviceID) + target, err = wireOptions.GetDPOPOptions().EvaluateTarget(clientID.DeviceID) if err != nil { return acme.WrapError(acme.ErrorMalformedType, err, "invalid Go template registered for 'target'") } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index da95bc23d..9daa2f70a 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -25,6 +25,9 @@ import ( "github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner/wire" + + sassert "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewOrderRequest_Validate(t *testing.T) { @@ -101,30 +104,33 @@ func TestNewOrderRequest_Validate(t *testing.T) { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "wireapp-device", Value: "{}"}, + {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, + {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, }, - err: acme.NewError(acme.ErrorMalformedType, `invalid Wire client ID "": invalid Wire client ID URI "": error parsing : scheme is missing`), + err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "example.com": invalid Wire client ID URI "example.com": error parsing example.com: scheme is missing`), } }, "fail/bad-identifier/wireapp-wrong-scheme": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ + {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "nowireapp://example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, }, - err: acme.NewError(acme.ErrorMalformedType, `invalid Wire client ID "nowireapp://example.com": invalid Wire client ID scheme "nowireapp"; expected "wireapp"`), + err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "nowireapp://example.com": invalid Wire client ID scheme "nowireapp"; expected "wireapp"`), } }, "fail/bad-identifier/wireapp-invalid-user-parts": func(t *testing.T) test { return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ + {Type: "wireapp-user", Value: `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://user-device@example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, }, - err: acme.NewError(acme.ErrorMalformedType, `invalid Wire client ID "wireapp://user-device@example.com": invalid Wire client ID username "user-device"`), + err: acme.NewError(acme.ErrorMalformedType, `failed validating Wire identifiers: invalid Wire client ID "wireapp://user-device@example.com": invalid Wire client ID username "user-device"`), } }, "ok": func(t *testing.T) test { @@ -205,27 +211,13 @@ func TestNewOrderRequest_Validate(t *testing.T) { naf: naf, } }, - "ok/wireapp-user": func(t *testing.T) test { + "ok/wireapp": func(t *testing.T) test { nbf := time.Now().UTC().Add(time.Minute) naf := time.Now().UTC().Add(5 * time.Minute) return test{ nor: &NewOrderRequest{ Identifiers: []acme.Identifier{ {Type: "wireapp-user", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, - }, - NotAfter: naf, - NotBefore: nbf, - }, - nbf: nbf, - naf: naf, - } - }, - "ok/wireapp-device": func(t *testing.T) test { - nbf := time.Now().UTC().Add(time.Minute) - naf := time.Now().UTC().Add(5 * time.Minute) - return test{ - nor: &NewOrderRequest{ - Identifiers: []acme.Identifier{ {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, }, NotAfter: naf, @@ -239,30 +231,30 @@ func TestNewOrderRequest_Validate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - if err := tc.nor.Validate(); err != nil { - if assert.NotNil(t, err) { - var ae *acme.Error - if assert.True(t, errors.As(err, &ae)) { - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } + err := tc.nor.Validate() + if tc.err != nil { + assert.Error(t, err) + var ae *acme.Error + if assert.True(t, errors.As(err, &ae)) { + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) } + return + } + + assert.NoError(t, err) + if tc.nbf.IsZero() { + assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute))) + assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute))) } else { - if assert.Nil(t, tc.err) { - if tc.nbf.IsZero() { - assert.True(t, tc.nor.NotBefore.Before(time.Now().Add(time.Minute))) - assert.True(t, tc.nor.NotBefore.After(time.Now().Add(-time.Minute))) - } else { - assert.Equals(t, tc.nor.NotBefore, tc.nbf) - } - if tc.naf.IsZero() { - assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour))) - assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute))) - } else { - assert.Equals(t, tc.nor.NotAfter, tc.naf) - } - } + assert.Equals(t, tc.nor.NotBefore, tc.nbf) + } + if tc.naf.IsZero() { + assert.True(t, tc.nor.NotAfter.Before(time.Now().Add(24*time.Hour))) + assert.True(t, tc.nor.NotAfter.After(time.Now().Add(24*time.Hour-time.Minute))) + } else { + assert.Equals(t, tc.nor.NotAfter, tc.naf) } }) } @@ -564,6 +556,37 @@ func TestHandler_GetOrder(t *testing.T) { func TestHandler_newAuthorization(t *testing.T) { defaultProvisioner := newProv() + fakeKey := `-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= +-----END PUBLIC KEY-----` + wireProvisioner := newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wire.Options{ + OIDC: &wire.OIDCOptions{ + Provider: &wire.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wire.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wire.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + }) + wireProvisionerFailOptions := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + Options: &provisioner.Options{}, + Challenges: []provisioner.ACMEChallenge{ + provisioner.WIREOIDC_01, + provisioner.WIREDPOP_01, + }, + } type test struct { az *acme.Authorization prov acme.Provisioner @@ -591,8 +614,13 @@ func TestHandler_newAuthorization(t *testing.T) { return errors.New("force") }, }, - az: az, - err: acme.NewErrorISE("error creating challenge: force"), + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Err: errors.New("error creating challenge: force"), + Detail: "The server experienced an internal error", + Status: 500, + }, } }, "fail/error-db.CreateAuthorization": func(t *testing.T) test { @@ -646,8 +674,101 @@ func TestHandler_newAuthorization(t *testing.T) { return errors.New("force") }, }, - az: az, - err: acme.NewErrorISE("error creating authorization: force"), + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Err: errors.New("error creating authorization: force"), + Detail: "The server experienced an internal error", + Status: 500, + }, + } + }, + "fail/wireapp-user-options": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "wireapp-user", + Value: "wireapp://%40alice.smith.qa@example.com", + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + return test{ + prov: wireProvisionerFailOptions, + db: &acme.MockDB{}, + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Err: errors.New("failed getting Wire options: no Wire options available"), + Detail: "The server experienced an internal error", + Status: 500, + }, + } + }, + "fail/wireapp-device-parse-id": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "wireapp-device", + Value: `{"name}`, + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + return test{ + prov: wireProvisioner, + db: &acme.MockDB{}, + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:malformed", + Err: errors.New("failed parsing WireDevice: unexpected end of JSON input"), + Detail: "The request message was malformed", + Status: 400, + }, + } + }, + "fail/wireapp-device-parse-client-id": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "wireapp-device", + Value: `{"name": "device", "domain": "wire.com", "client-id": "CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`, + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + return test{ + prov: wireProvisioner, + db: &acme.MockDB{}, + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:malformed", + Err: errors.New("failed parsing ClientID: invalid Wire client ID URI \"CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com\": error parsing CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com: scheme is missing"), + Detail: "The request message was malformed", + Status: 400, + }, + } + }, + "fail/wireapp-device-options": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "wireapp-device", + Value: `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`, + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + return test{ + prov: wireProvisionerFailOptions, + db: &acme.MockDB{}, + az: az, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Err: errors.New("failed getting Wire options: no Wire options available"), + Detail: "The server experienced an internal error", + Status: 500, + }, } }, "ok/no-wildcard": func(t *testing.T) test { @@ -816,12 +937,12 @@ func TestHandler_newAuthorization(t *testing.T) { az: az, } }, - "ok/wire": func(t *testing.T) test { + "ok/wireapp-user": func(t *testing.T) test { az := &acme.Authorization{ AccountID: "accID", Identifier: acme.Identifier{ - Type: "wireapp", - Value: "wireapp://user!client@domain", + Type: "wireapp-user", + Value: "wireapp://%40alice.smith.qa@example.com", }, Status: acme.StatusPending, ExpiresAt: clock.Now(), @@ -829,12 +950,12 @@ func TestHandler_newAuthorization(t *testing.T) { count := 0 var ch1 **acme.Challenge return test{ - prov: defaultProvisioner, + prov: wireProvisioner, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { switch count { case 0: - ch.ID = "wireapp" + ch.ID = "wireapp-user" assert.Equals(t, ch.Type, acme.WIREOIDC01) ch1 = &ch default: @@ -863,31 +984,73 @@ func TestHandler_newAuthorization(t *testing.T) { az: az, } }, + "ok/wireapp-device": func(t *testing.T) test { + az := &acme.Authorization{ + AccountID: "accID", + Identifier: acme.Identifier{ + Type: "wireapp-device", + Value: `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}`, + }, + Status: acme.StatusPending, + ExpiresAt: clock.Now(), + } + count := 0 + var ch1 **acme.Challenge + return test{ + prov: wireProvisioner, + db: &acme.MockDB{ + MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { + switch count { + case 0: + ch.ID = "wireapp-device" + assert.Equals(t, ch.Type, acme.WIREDPOP01) + ch1 = &ch + default: + assert.FatalError(t, errors.New("test logic error")) + return errors.New("force") + } + count++ + assert.Equals(t, ch.AccountID, az.AccountID) + assert.Equals(t, ch.Token, az.Token) + assert.Equals(t, ch.Status, acme.StatusPending) + assert.Equals(t, ch.Value, az.Identifier.Value) + return nil + }, + MockCreateAuthorization: func(ctx context.Context, _az *acme.Authorization) error { + assert.Equals(t, _az.AccountID, az.AccountID) + assert.Equals(t, _az.Token, az.Token) + assert.Equals(t, _az.Status, acme.StatusPending) + assert.Equals(t, _az.Identifier, az.Identifier) + assert.Equals(t, _az.ExpiresAt, az.ExpiresAt) + _ = ch1 + // assert.Equals(t, _az.Challenges, []*acme.Challenge{*ch1}) + assert.Equals(t, _az.Wildcard, false) + return nil + }, + }, + az: az, + } + }, } for name, run := range tests { t.Run(name, func(t *testing.T) { - if name == "ok/permanent-identifier-enabled" { - println(1) - } tc := run(t) ctx := newBaseContext(context.Background(), tc.db) ctx = acme.NewProvisionerContext(ctx, tc.prov) - if err := newAuthorization(ctx, tc.az); err != nil { - if assert.NotNil(t, tc.err) { - var k *acme.Error - if assert.True(t, errors.As(err, &k)) { - assert.Equals(t, k.Type, tc.err.Type) - assert.Equals(t, k.Detail, tc.err.Detail) - assert.Equals(t, k.Status, tc.err.Status) - assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) - assert.Equals(t, k.Detail, tc.err.Detail) - } else { - assert.FatalError(t, errors.New("unexpected error type")) - } + err := newAuthorization(ctx, tc.az) + if tc.err != nil { + sassert.Error(t, err) + var k *acme.Error + if sassert.True(t, errors.As(err, &k)) { + sassert.Equal(t, tc.err.Type, k.Type) + sassert.Equal(t, tc.err.Detail, k.Detail) + sassert.Equal(t, tc.err.Status, k.Status) + sassert.EqualError(t, k.Err, tc.err.Error()) } - } else { - assert.Nil(t, tc.err) + return } + + sassert.NoError(t, err) }) } } @@ -1734,7 +1897,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= }, } }, - "ok/default-naf-nbf-wireapp-user": func(t *testing.T) test { + "ok/default-naf-nbf-wireapp": func(t *testing.T) test { acmeWireProv := newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wire.Options{ OIDC: &wire.OIDCOptions{ @@ -1764,7 +1927,8 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= acc := &acme.Account{ID: "accID"} nor := &NewOrderRequest{ Identifiers: []acme.Identifier{ - {Type: "wireapp-user", Value: `{"name": "Alice Smith", "handle": "wireapp://%40alice.smith.qa@example.com"}`}, + {Type: "wireapp-user", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, + {Type: "wireapp-device", Value: `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice_wire@wire.com"}`}, }, } b, err := json.Marshal(nor) @@ -1773,8 +1937,9 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( - ch1 **acme.Challenge - az1ID *string + ch1, ch2 **acme.Challenge + az1ID, az2ID *string + chCount, azCount = 0, 0 ) return test{ ctx: ctx, @@ -1783,127 +1948,49 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= ca: &mockCA{}, db: &acme.MockDB{ MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { - ch.ID = "wireapp-oidc" - assert.Equals(t, ch.Type, acme.WIREOIDC01) - ch1 = &ch + switch chCount { + case 0: + assert.Equals(t, ch.Type, acme.WIREOIDC01) + assert.Equals(t, ch.Value, `{"name": "Smith, Alice M (QA)", "domain": "example.com", "handle": "wireapp://%40alice_wire@wire.com"}`) + ch.ID = "wireapp-oidc" + ch1 = &ch + case 1: + assert.Equals(t, ch.Type, acme.WIREDPOP01) + assert.Equals(t, ch.Value, `{"name": "Smith, Alice M (QA)", "domain": "example.com", "client-id": "wireapp://lJGYPz0ZRq2kvc_XpdaDlA!ed416ce8ecdd9fad@example.com", "handle": "wireapp://%40alice_wire@wire.com"}`) + ch.ID = "wireapp-dpop" + ch2 = &ch + default: + require.Fail(t, "test logic error") + } + chCount++ assert.Equals(t, ch.AccountID, "accID") assert.NotEquals(t, ch.Token, "") assert.Equals(t, ch.Status, acme.StatusPending) - assert.Equals(t, ch.Value, `{"name": "Alice Smith", "handle": "wireapp://%40alice.smith.qa@example.com"}`) - return nil - }, - MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { - az.ID = "az1ID" - az1ID = &az.ID - assert.Equals(t, az.AccountID, "accID") - assert.NotEquals(t, az.Token, "") - assert.Equals(t, az.Status, acme.StatusPending) - assert.Equals(t, az.Identifier, nor.Identifiers[0]) - assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1}) - assert.Equals(t, az.Wildcard, false) - return nil - }, - MockCreateOrder: func(ctx context.Context, o *acme.Order) error { - o.ID = "ordID" - assert.Equals(t, o.AccountID, "accID") - assert.Equals(t, o.ProvisionerID, prov.GetID()) - assert.Equals(t, o.Status, acme.StatusPending) - assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) - return nil - }, - MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { - assert.Equals(t, prov.GetID(), provisionerID) - assert.Equals(t, "accID", accountID) - return nil, nil - }, - }, - vr: func(t *testing.T, o *acme.Order) { - now := clock.Now() - testBufferDur := 5 * time.Second - orderExpiry := now.Add(defaultOrderExpiry) - expNbf := now.Add(-defaultOrderBackdate) - expNaf := now.Add(prov.DefaultTLSCertDuration()) - assert.Equals(t, o.ID, "ordID") - assert.Equals(t, o.Status, acme.StatusPending) - assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) - assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) - assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) - assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) - assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf)) - assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry)) - assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry)) - }, - } - }, - "ok/default-naf-nbf-wireapp-device": func(t *testing.T) test { - acmeWireProv := newWireProvisionerWithOptions(t, &provisioner.Options{ - Wire: &wire.Options{ - OIDC: &wire.OIDCOptions{ - Provider: &wire.Provider{ - IssuerURL: "https://issuer.example.com", - AuthURL: "", - TokenURL: "", - JWKSURL: "", - UserInfoURL: "", - Algorithms: []string{"ES256"}, - }, - Config: &wire.Config{ - ClientID: "integration test", - SignatureAlgorithms: []string{"ES256"}, - SkipClientIDCheck: true, - SkipExpiryCheck: true, - SkipIssuerCheck: true, - InsecureSkipSignatureCheck: true, - Now: time.Now, - }, - }, - DPOP: &wire.DPOPOptions{ - SigningKey: []byte(fakeWireSigningKey), - }, - }, - }) - acc := &acme.Account{ID: "accID"} - nor := &NewOrderRequest{ - Identifiers: []acme.Identifier{ - {Type: "wireapp-device", Value: `{"client-id": "wireapp://user!client@domain"}`}, - }, - } - b, err := json.Marshal(nor) - assert.FatalError(t, err) - ctx := acme.NewProvisionerContext(context.Background(), acmeWireProv) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - var ( - ch1 **acme.Challenge - az1ID *string - ) - return test{ - ctx: ctx, - statusCode: 201, - nor: nor, - ca: &mockCA{}, - db: &acme.MockDB{ - MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { - ch.ID = "wireapp-dpop" - assert.Equals(t, ch.Type, acme.WIREDPOP01) - ch1 = &ch - assert.Equals(t, ch.AccountID, "accID") - assert.NotEquals(t, ch.Token, "") - assert.Equals(t, ch.Status, acme.StatusPending) - assert.Equals(t, ch.Value, `{"client-id": "wireapp://user!client@domain"}`) + _, _ = ch1, ch2 + return nil }, MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error { - az.ID = "az1ID" - az1ID = &az.ID + switch azCount { + case 0: + az.ID = "az1ID" + az1ID = &az.ID + assert.Equals(t, az.Identifier, nor.Identifiers[0]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1}) + case 1: + az.ID = "az2ID" + az2ID = &az.ID + assert.Equals(t, az.Identifier, nor.Identifiers[1]) + assert.Equals(t, az.Challenges, []*acme.Challenge{*ch2}) + default: + require.Fail(t, "test logic error") + } + azCount++ + assert.Equals(t, az.AccountID, "accID") assert.NotEquals(t, az.Token, "") assert.Equals(t, az.Status, acme.StatusPending) - assert.Equals(t, az.Identifier, nor.Identifiers[0]) - assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1}) assert.Equals(t, az.Wildcard, false) return nil }, @@ -1913,7 +2000,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= assert.Equals(t, o.ProvisionerID, prov.GetID()) assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) + assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) return nil }, MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) { @@ -1932,7 +2019,10 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Identifiers, nor.Identifiers) - assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)}) + assert.Equals(t, o.AuthorizationURLs, []string{ + fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName), + fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName), + }) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) diff --git a/acme/api/wire_integration_test.go b/acme/api/wire_integration_test.go index c59af0693..56e94cd26 100644 --- a/acme/api/wire_integration_test.go +++ b/acme/api/wire_integration_test.go @@ -25,6 +25,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner/wire" nosqlDB "github.com/smallstep/nosql" @@ -42,16 +43,23 @@ const ( ) func newWireProvisionerWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME { - p := newProvWithOptions(options) - a, ok := p.(*provisioner.ACME) - if !ok { - t.Fatal("not a valid ACME provisioner") - } - a.Challenges = []provisioner.ACMEChallenge{ - provisioner.WIREOIDC_01, - provisioner.WIREDPOP_01, + t.Helper() + prov := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + Options: options, + Challenges: []provisioner.ACMEChallenge{ + provisioner.WIREOIDC_01, + provisioner.WIREDPOP_01, + }, } - return a + + err := prov.Init(provisioner.Config{ + Claims: config.GlobalProvisionerClaims, + }) + require.NoError(t, err) + + return prov } // TODO(hs): replace with test CA server + acmez based test client for diff --git a/acme/challenge.go b/acme/challenge.go index 0a80f3ed0..465125e91 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -362,29 +362,31 @@ func wireOIDC01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSO if !ok { return NewErrorISE("missing provisioner") } + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return WrapErrorISE(err, "failed getting Wire options") + } linker, ok := LinkerFromContext(ctx) if !ok { return NewErrorISE("missing linker") } var oidcPayload wireOidcPayload - err := json.Unmarshal(payload, &oidcPayload) - if err != nil { + if err := json.Unmarshal(payload, &oidcPayload); err != nil { return WrapError(ErrorMalformedType, err, "error unmarshalling Wire OIDC challenge payload") } - wireID, err := wire.ParseUserID([]byte(ch.Value)) + wireID, err := wire.ParseUserID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } - wireOptions, err := prov.GetOptions().GetWireOptions() + oidcOptions := wireOptions.GetOIDCOptions() + verifier, err := oidcOptions.GetVerifier(ctx) if err != nil { - return WrapErrorISE(err, "failed getting Wire options") + return WrapErrorISE(err, "no OIDC verifier available") } - oidcOptions := wireOptions.GetOIDCOptions() - verifier := oidcOptions.GetProvider(ctx).Verifier(oidcOptions.GetConfig()) idToken, err := verifier.Verify(ctx, oidcPayload.IDToken) if err != nil { return storeError(ctx, db, ch, true, WrapError(ErrorRejectedIdentifierType, err, @@ -490,6 +492,10 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j if !ok { return NewErrorISE("missing provisioner") } + wireOptions, err := prov.GetOptions().GetWireOptions() + if err != nil { + return WrapErrorISE(err, "failed getting Wire options") + } linker, ok := LinkerFromContext(ctx) if !ok { return NewErrorISE("missing linker") @@ -500,7 +506,7 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j return WrapError(ErrorMalformedType, err, "error unmarshalling Wire DPoP challenge payload") } - wireID, err := wire.ParseDeviceID([]byte(ch.Value)) + wireID, err := wire.ParseDeviceID(ch.Value) if err != nil { return WrapErrorISE(err, "error unmarshalling challenge data") } @@ -510,11 +516,6 @@ func wireDPOP01Validate(ctx context.Context, ch *Challenge, db DB, accountJWK *j return WrapErrorISE(err, "error parsing device id") } - wireOptions, err := prov.GetOptions().GetWireOptions() - if err != nil { - return WrapErrorISE(err, "failed getting Wire options") - } - dpopOptions := wireOptions.GetDPOPOptions() issuer, err := dpopOptions.EvaluateTarget(clientID.DeviceID) if err != nil { @@ -721,6 +722,14 @@ func parseAndVerifyWireAccessToken(v wireVerifyParams) (*wireAccessToken, *wireD return nil, nil, fmt.Errorf("invalid Wire client handle %q", handle) } + name, ok := dpopToken["name"].(string) + if !ok { + return nil, nil, fmt.Errorf("invalid display name in Wire DPoP token") + } + if name == "" || name != v.wireID.Name { + return nil, nil, fmt.Errorf("invalid Wire client display name %q", name) + } + return &accessToken, &dpopToken, nil } diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 35d943765..4f09535dc 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -1008,6 +1008,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -1017,6 +1018,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) diff --git a/acme/challenge_wire_test.go b/acme/challenge_wire_test.go index b7881fa9c..1ac381ce3 100644 --- a/acme/challenge_wire_test.go +++ b/acme/challenge_wire_test.go @@ -47,7 +47,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/no-linker": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) return test{ ctx: ctx, expectedErr: &Error{ @@ -59,7 +77,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/unmarshal": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, @@ -82,7 +118,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/wire-parse-id": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, @@ -105,7 +159,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/wire-parse-client-id": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) valueBytes, err := json.Marshal(struct { Name string `json:"name,omitempty"` @@ -139,41 +211,6 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= }, } }, - "fail/no-wire-options": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) - ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) - valueBytes, err := json.Marshal(struct { - Name string `json:"name,omitempty"` - Domain string `json:"domain,omitempty"` - ClientID string `json:"client-id,omitempty"` - Handle string `json:"handle,omitempty"` - }{ - Name: "Alice Smith", - Domain: "wire.com", - ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", - Handle: "wireapp://%40alice_wire@wire.com", - }) - require.NoError(t, err) - return test{ - ctx: ctx, - payload: []byte("{}"), - ch: &Challenge{ - ID: "chID", - AuthorizationID: "azID", - AccountID: "accID", - Token: "token", - Type: "wire-dpop-01", - Status: StatusPending, - Value: string(valueBytes), - }, - expectedErr: &Error{ - Type: "urn:ietf:params:acme:error:serverInternal", - Detail: "The server experienced an internal error", - Status: 500, - Err: errors.New(`failed getting Wire options: no Wire options available`), - }, - } - }, "fail/parse-and-verify": func(t *testing.T) test { ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ Wire: &wireprovisioner.Options{ @@ -269,6 +306,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -278,6 +316,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) @@ -413,6 +452,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -422,6 +462,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) @@ -561,6 +602,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -570,6 +612,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) @@ -709,6 +752,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -718,6 +762,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) @@ -864,6 +909,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle string `json:"handle,omitempty"` Nonce string `json:"nonce,omitempty"` HTU string `json:"htu,omitempty"` + Name string `json:"name,omitempty"` }{ Claims: jose.Claims{ Subject: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", @@ -873,6 +919,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= Handle: "wireapp://%40alice_wire@wire.com", Nonce: "nonce", HTU: "http://issuer.example.com", + Name: "Alice Smith", }) require.NoError(t, err) dpop, err := dpopSigner.Sign(dpopBytes) @@ -1037,7 +1084,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/no-linker": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) return test{ ctx: ctx, expectedErr: &Error{ @@ -1049,7 +1114,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/unmarshal": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, @@ -1078,7 +1161,25 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= } }, "fail/wire-parse-id": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) + ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{ + Wire: &wireprovisioner.Options{ + OIDC: &wireprovisioner.OIDCOptions{ + Provider: &wireprovisioner.Provider{ + IssuerURL: "https://issuer.example.com", + Algorithms: []string{"ES256"}, + }, + Config: &wireprovisioner.Config{ + ClientID: "test", + SignatureAlgorithms: []string{"ES256"}, + Now: time.Now, + }, + TransformTemplate: "", + }, + DPOP: &wireprovisioner.DPOPOptions{ + SigningKey: []byte(fakeKey), + }, + }, + })) ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) return test{ ctx: ctx, @@ -1100,41 +1201,6 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= }, } }, - "fail/no-wire-options": func(t *testing.T) test { - ctx := NewProvisionerContext(context.Background(), newWireProvisionerWithOptions(t, &provisioner.Options{})) - ctx = NewLinkerContext(ctx, NewLinker("ca.example.com", "acme")) - valueBytes, err := json.Marshal(struct { - Name string `json:"name,omitempty"` - Domain string `json:"domain,omitempty"` - ClientID string `json:"client-id,omitempty"` - Handle string `json:"handle,omitempty"` - }{ - Name: "Alice Smith", - Domain: "wire.com", - ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", - Handle: "wireapp://%40alice_wire@wire.com", - }) - require.NoError(t, err) - return test{ - ctx: ctx, - payload: []byte("{}"), - ch: &Challenge{ - ID: "chID", - AuthorizationID: "azID", - AccountID: "accID", - Token: "token", - Type: "wire-oidc-01", - Status: StatusPending, - Value: string(valueBytes), - }, - expectedErr: &Error{ - Type: "urn:ietf:params:acme:error:serverInternal", - Detail: "The server experienced an internal error", - Status: 500, - Err: errors.New(`failed getting Wire options: no Wire options available`), - }, - } - }, "fail/verify": func(t *testing.T) test { jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token") signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -2122,8 +2188,8 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= idTokenString := `eyJhbGciOiJSUzI1NiIsImtpZCI6IjZhNDZlYzQ3YTQzYWI1ZTc4NzU3MzM5NWY1MGY4ZGQ5MWI2OTM5MzcifQ.eyJpc3MiOiJodHRwOi8vZGV4OjE1ODE4L2RleCIsInN1YiI6IkNqcDNhWEpsWVhCd09pOHZTMmh0VjBOTFpFTlRXakoyT1dWTWFHRk9XVlp6WnlFeU5UZzFNVEpoT0RRek5qTXhaV1V6UUhkcGNtVXVZMjl0RWdSc1pHRnciLCJhdWQiOiJ3aXJlYXBwIiwiZXhwIjoxNzA1MDkxNTYyLCJpYXQiOjE3MDUwMDUxNjIsIm5vbmNlIjoib0VjUzBRQUNXLVIyZWkxS09wUmZ2QSIsImF0X2hhc2giOiJoYzk0NmFwS25FeEV5TDVlSzJZMzdRIiwiY19oYXNoIjoidmRubFp2V1d1bVd1Z2NYR1JpOU5FUSIsIm5hbWUiOiJ3aXJlYXBwOi8vJTQwYWxpY2Vfd2lyZUB3aXJlLmNvbSIsInByZWZlcnJlZF91c2VybmFtZSI6IkFsaWNlIFNtaXRoIn0.aEBhWJugBJ9J_0L_4odUCg8SR8HMXVjd__X8uZRo42BSJQQO7-wdpy0jU3S4FOX9fQKr68wD61gS_QsnhfiT7w9U36mLpxaYlNVDCYfpa-gklVFit_0mjUOukXajTLK6H527TGiSss8z22utc40ckS1SbZa2BzKu3yOcqnFHUQwQc5sLYfpRABTB6WBoYFtnWDzdpyWJDaOzz7lfKYv2JBnf9vV8u8SYm-6gNKgtiQ3UUnjhIVUjdfHet2BMvmV2ooZ8V441RULCzKKG_sWZba-D_k_TOnSholGobtUOcKHlmVlmfUe8v7kuyBdhbPcembfgViaNldLQGKZjZfgvLg` ctx := context.Background() o := opts.GetOIDCOptions() - c := o.GetConfig() - verifier := o.GetProvider(ctx).Verifier(c) + verifier, err := o.GetVerifier(ctx) + require.NoError(t, err) idToken, err := verifier.Verify(ctx, idTokenString) require.NoError(t, err) diff --git a/acme/order.go b/acme/order.go index 974bac5f7..1175bc385 100644 --- a/acme/order.go +++ b/acme/order.go @@ -340,7 +340,7 @@ func createWireSubject(o *Order, csr *x509.CertificateRequest) (subject x509util for _, identifier := range o.Identifiers { switch identifier.Type { case WireUser: - wireID, err := wire.ParseUserID([]byte(identifier.Value)) + wireID, err := wire.ParseUserID(identifier.Value) if err != nil { return subject, NewErrorISE("unmarshal wireID: %s", err) } @@ -406,7 +406,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ orderPIDs[indexPID] = n.Value indexPID++ case WireUser: - wireID, err := wire.ParseUserID([]byte(n.Value)) + wireID, err := wire.ParseUserID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } @@ -417,7 +417,7 @@ func (o *Order) sans(csr *x509.CertificateRequest) ([]x509util.SubjectAlternativ tmpOrderURIs[indexURI] = handle indexURI++ case WireDevice: - wireID, err := wire.ParseDeviceID([]byte(n.Value)) + wireID, err := wire.ParseDeviceID(n.Value) if err != nil { return sans, NewErrorISE("unsupported identifier value in order: %s", n.Value) } diff --git a/acme/wire/id.go b/acme/wire/id.go index 3662cc701..d1abcedbb 100644 --- a/acme/wire/id.go +++ b/acme/wire/id.go @@ -2,6 +2,7 @@ package wire import ( "encoding/json" + "errors" "fmt" "strings" @@ -21,13 +22,39 @@ type DeviceID struct { Handle string `json:"handle,omitempty"` } -func ParseUserID(data []byte) (id UserID, err error) { - err = json.Unmarshal(data, &id) +func ParseUserID(value string) (id UserID, err error) { + if err = json.Unmarshal([]byte(value), &id); err != nil { + return + } + + switch { + case id.Handle == "": + err = errors.New("handle must not be empty") + case id.Name == "": + err = errors.New("name must not be empty") + case id.Domain == "": + err = errors.New("domain must not be empty") + } + return } -func ParseDeviceID(data []byte) (id DeviceID, err error) { - err = json.Unmarshal(data, &id) +func ParseDeviceID(value string) (id DeviceID, err error) { + if err = json.Unmarshal([]byte(value), &id); err != nil { + return + } + + switch { + case id.Handle == "": + err = errors.New("handle must not be empty") + case id.Name == "": + err = errors.New("name must not be empty") + case id.Domain == "": + err = errors.New("domain must not be empty") + case id.ClientID == "": + err = errors.New("client-id must not be empty") + } + return } diff --git a/acme/wire/id_test.go b/acme/wire/id_test.go index 4c008e462..3cf114b7d 100644 --- a/acme/wire/id_test.go +++ b/acme/wire/id_test.go @@ -9,19 +9,27 @@ import ( func TestParseUserID(t *testing.T) { ok := `{"name": "Alice Smith", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` + failJSON := `{"name": }` + emptyHandle := `{"name": "Alice Smith", "domain": "wire.com", "handle": ""}` + emptyName := `{"name": "", "domain": "wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` + emptyDomain := `{"name": "Alice Smith", "domain": "", "handle": "wireapp://%40alice_wire@wire.com"}` tests := []struct { - name string - data []byte - wantWireID UserID - expectedErr error + name string + value string + wantWireID UserID + wantErr bool }{ - {name: "ok", data: []byte(ok), wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "ok", value: ok, wantWireID: UserID{Name: "Alice Smith", Domain: "wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "fail/json", value: failJSON, wantErr: true}, + {name: "fail/empty-handle", value: emptyHandle, wantErr: true}, + {name: "fail/empty-name", value: emptyName, wantErr: true}, + {name: "fail/empty-domain", value: emptyDomain, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotWireID, err := ParseUserID(tt.data) - if tt.expectedErr != nil { - assert.EqualError(t, err, tt.expectedErr.Error()) + gotWireID, err := ParseUserID(tt.value) + if tt.wantErr { + assert.Error(t, err) return } @@ -33,19 +41,29 @@ func TestParseUserID(t *testing.T) { func TestParseDeviceID(t *testing.T) { ok := `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` + failJSON := `{"name": }` + emptyHandle := `{"name": "device", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": ""}` + emptyName := `{"name": "", "domain": "wire.com", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` + emptyDomain := `{"name": "device", "domain": "", "client-id": "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", "handle": "wireapp://%40alice_wire@wire.com"}` + emptyClientID := `{"name": "device", "domain": "wire.com", "client-id": "", "handle": "wireapp://%40alice_wire@wire.com"}` tests := []struct { - name string - data []byte - wantWireID DeviceID - expectedErr error + name string + value string + wantWireID DeviceID + wantErr bool }{ - {name: "ok", data: []byte(ok), wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "ok", value: ok, wantWireID: DeviceID{Name: "device", Domain: "wire.com", ClientID: "wireapp://CzbfFjDOQrenCbDxVmgnFw!594930e9d50bb175@wire.com", Handle: "wireapp://%40alice_wire@wire.com"}}, + {name: "fail/json", value: failJSON, wantErr: true}, + {name: "fail/empty-handle", value: emptyHandle, wantErr: true}, + {name: "fail/empty-name", value: emptyName, wantErr: true}, + {name: "fail/empty-domain", value: emptyDomain, wantErr: true}, + {name: "fail/empty-client-id", value: emptyClientID, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotWireID, err := ParseDeviceID(tt.data) - if tt.expectedErr != nil { - assert.EqualError(t, err, tt.expectedErr.Error()) + gotWireID, err := ParseDeviceID(tt.value) + if tt.wantErr { + assert.Error(t, err) return } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index f338a78ae..3b7fa654d 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -107,7 +107,8 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` // Challenges contains the enabled challenges for this provisioner. If this // value is not set the default http-01, dns-01 and tls-alpn-01 challenges - // will be enabled, device-attest-01 will be disabled. + // will be enabled, device-attest-01, wire-oidc-01 and wire-dpop-01 will be + // disabled. Challenges []ACMEChallenge `json:"challenges,omitempty"` // AttestationFormats contains the enabled attestation formats for this // provisioner. If this value is not set the default apple, step and tpm @@ -211,10 +212,50 @@ func (p *ACME) Init(config Config) (err error) { } } + if err := p.initializeWireOptions(); err != nil { + return fmt.Errorf("failed initializing Wire options: %w", err) + } + p.ctl, err = NewController(p, p.Claims, config, p.Options) return } +// initializeWireOptions initializes the options for the ACME Wire +// integration. It'll return early if no Wire challenge types are +// enabled. +func (p *ACME) initializeWireOptions() error { + hasWireChallenges := false + for _, c := range p.Challenges { + if c == WIREOIDC_01 || c == WIREDPOP_01 { + hasWireChallenges = true + break + } + } + if !hasWireChallenges { + return nil + } + + w, err := p.GetOptions().GetWireOptions() + if err != nil { + return fmt.Errorf("failed getting Wire options: %w", err) + } + + if err := w.Validate(); err != nil { + return fmt.Errorf("failed validating Wire options: %w", err) + } + + // at this point the Wire options have been validated, and (mostly) + // initialized. Remote keys will be loaded upon the first verification, + // currently. + // TODO(hs): can/should we "prime" the underlying remote keyset, to verify + // auto discovery works as expected? Because of the current way provisioners + // are initialized, doing that as part of the initialization isn't the best + // time to do it, because it could result in operations not resulting in the + // expected result in all cases. + + return nil +} + // ACMEIdentifierType encodes ACME Identifier types type ACMEIdentifierType string @@ -254,13 +295,13 @@ func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdenti err = x509Policy.IsDNSAllowed(identifier.Value) case WireUser: var wireID wire.UserID - if wireID, err = wire.ParseUserID([]byte(identifier.Value)); err != nil { + if wireID, err = wire.ParseUserID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.Handle}) case WireDevice: var wireID wire.DeviceID - if wireID, err = wire.ParseDeviceID([]byte(identifier.Value)); err != nil { + if wireID, err = wire.ParseDeviceID(identifier.Value); err != nil { return fmt.Errorf("failed parsing Wire SANs: %w", err) } err = x509Policy.AreSANsAllowed([]string{wireID.ClientID}) diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 94684ce19..96f4bd8b3 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -1,6 +1,3 @@ -//go:build !go1.18 -// +build !go1.18 - package provisioner import ( @@ -14,8 +11,10 @@ import ( "testing" "time" - "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner/wire" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestACMEChallenge_Validate(t *testing.T) { @@ -28,14 +27,20 @@ func TestACMEChallenge_Validate(t *testing.T) { {"dns-01", DNS_01, false}, {"tls-alpn-01", TLS_ALPN_01, false}, {"device-attest-01", DEVICE_ATTEST_01, false}, + {"wire-oidc-01", DEVICE_ATTEST_01, false}, + {"wire-dpop-01", DEVICE_ATTEST_01, false}, {"uppercase", "HTTP-01", false}, {"fail", "http-02", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.c.Validate(); (err != nil) != tt.wantErr { - t.Errorf("ACMEChallenge.Validate() error = %v, wantErr %v", err, tt.wantErr) + err := tt.c.Validate() + if tt.wantErr { + assert.Error(t, err) + return } + + assert.NoError(t, err) }) } } @@ -54,26 +59,24 @@ func TestACMEAttestationFormat_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.f.Validate(); (err != nil) != tt.wantErr { - t.Errorf("ACMEAttestationFormat.Validate() error = %v, wantErr %v", err, tt.wantErr) + err := tt.f.Validate() + if tt.wantErr { + assert.Error(t, err) + return } + + assert.NoError(t, err) }) } } func TestACME_Getters(t *testing.T) { p, err := generateACME() - assert.FatalError(t, err) - id := "acme/" + p.Name - if got := p.GetID(); got != id { - t.Errorf("ACME.GetID() = %v, want %v", got, id) - } - if got := p.GetName(); got != p.Name { - t.Errorf("ACME.GetName() = %v, want %v", got, p.Name) - } - if got := p.GetType(); got != TypeACME { - t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME) - } + require.NoError(t, err) + id := "acme/test@acme-provisioner.com" + assert.Equal(t, id, p.GetID()) + assert.Equal(t, "test@acme-provisioner.com", p.GetName()) + assert.Equal(t, TypeACME, p.GetType()) kid, key, ok := p.GetEncryptedKey() if kid != "" || key != "" || ok == true { t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", @@ -83,26 +86,25 @@ func TestACME_Getters(t *testing.T) { func TestACME_Init(t *testing.T) { appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + fakeWireDPoPKey := []byte(`-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= +-----END PUBLIC KEY-----`) type ProvisionerValidateTest struct { p *ACME err error } tests := map[string]func(*testing.T) ProvisionerValidateTest{ - "fail-empty": func(t *testing.T) ProvisionerValidateTest { + "fail/empty": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{}, err: errors.New("provisioner type cannot be empty"), } }, - "fail-empty-name": func(t *testing.T) ProvisionerValidateTest { + "fail/empty-name": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Type: "ACME", @@ -110,60 +112,119 @@ func TestACME_Init(t *testing.T) { err: errors.New("provisioner name cannot be empty"), } }, - "fail-empty-type": func(t *testing.T) ProvisionerValidateTest { + "fail/empty-type": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{Name: "foo"}, err: errors.New("provisioner type cannot be empty"), } }, - "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { + "fail/bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar", Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + p: &ACME{Name: "foo", Type: "ACME", Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, - "fail-bad-challenge": func(t *testing.T) ProvisionerValidateTest { + "fail/bad-challenge": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar", Challenges: []ACMEChallenge{HTTP_01, "zar"}}, + p: &ACME{Name: "foo", Type: "ACME", Challenges: []ACMEChallenge{HTTP_01, "zar"}}, err: errors.New("acme challenge \"zar\" is not supported"), } }, - "fail-bad-attestation-format": func(t *testing.T) ProvisionerValidateTest { + "fail/bad-attestation-format": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar", AttestationFormats: []ACMEAttestationFormat{APPLE, "zar"}}, + p: &ACME{Name: "foo", Type: "ACME", AttestationFormats: []ACMEAttestationFormat{APPLE, "zar"}}, err: errors.New("acme attestation format \"zar\" is not supported"), } }, - "fail-parse-attestation-roots": func(t *testing.T) ProvisionerValidateTest { + "fail/parse-attestation-roots": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar", AttestationRoots: []byte("-----BEGIN CERTIFICATE-----\nZm9v\n-----END CERTIFICATE-----")}, + p: &ACME{Name: "foo", Type: "ACME", AttestationRoots: []byte("-----BEGIN CERTIFICATE-----\nZm9v\n-----END CERTIFICATE-----")}, err: errors.New("error parsing attestationRoots: malformed certificate"), } }, - "fail-empty-attestation-roots": func(t *testing.T) ProvisionerValidateTest { + "fail/empty-attestation-roots": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar", AttestationRoots: []byte("\n")}, + p: &ACME{Name: "foo", Type: "ACME", AttestationRoots: []byte("\n")}, err: errors.New("error parsing attestationRoots: no certificates found"), } }, + "fail/wire-missing-options": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Name: "foo", + Type: "ACME", + Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, + }, + err: errors.New("failed initializing Wire options: failed getting Wire options: no options available"), + } + }, + "fail/wire-missing-wire-options": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Name: "foo", + Type: "ACME", + Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, + Options: &Options{}, + }, + err: errors.New("failed initializing Wire options: failed getting Wire options: no Wire options available"), + } + }, + "fail/wire-validate-options": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Name: "foo", + Type: "ACME", + Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, + Options: &Options{ + Wire: &wire.Options{ + OIDC: &wire.OIDCOptions{}, + DPOP: &wire.DPOPOptions{ + SigningKey: fakeWireDPoPKey, + }, + }, + }, + }, + err: errors.New("failed initializing Wire options: failed validating Wire options: failed initializing OIDC options: provider not set"), + } + }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &ACME{Name: "foo", Type: "bar"}, + p: &ACME{Name: "foo", Type: "ACME"}, } }, - "ok attestation": func(t *testing.T) ProvisionerValidateTest { + "ok/attestation": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &ACME{ Name: "foo", - Type: "bar", + Type: "ACME", Challenges: []ACMEChallenge{DNS_01, DEVICE_ATTEST_01}, AttestationFormats: []ACMEAttestationFormat{APPLE, STEP}, AttestationRoots: bytes.Join([][]byte{appleCA, yubicoCA}, []byte("\n")), }, } }, + "ok/wire": func(t *testing.T) ProvisionerValidateTest { + return ProvisionerValidateTest{ + p: &ACME{ + Name: "foo", + Type: "ACME", + Challenges: []ACMEChallenge{WIREOIDC_01, WIREDPOP_01}, + Options: &Options{ + Wire: &wire.Options{ + OIDC: &wire.OIDCOptions{ + Provider: &wire.Provider{ + IssuerURL: "https://issuer.example.com", + }, + }, + DPOP: &wire.DPOPOptions{ + SigningKey: fakeWireDPoPKey, + }, + }, + }, + }, + } + }, } - config := Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, @@ -173,13 +234,12 @@ func TestACME_Init(t *testing.T) { tc := get(t) t.Log(string(tc.p.AttestationRoots)) err := tc.p.Init(config) - if err != nil { - if assert.NotNil(t, tc.err) { - assert.Equals(t, tc.err.Error(), err.Error()) - } - } else { - assert.Nil(t, tc.err) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return } + + assert.NoError(t, err) }) } } @@ -195,12 +255,12 @@ func TestACME_AuthorizeRenew(t *testing.T) { tests := map[string]func(*testing.T) test{ "fail/renew-disabled": func(t *testing.T) test { p, err := generateACME() - assert.FatalError(t, err) + require.NoError(t, err) // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, cert: &x509.Certificate{ @@ -213,7 +273,7 @@ func TestACME_AuthorizeRenew(t *testing.T) { }, "ok": func(t *testing.T) test { p, err := generateACME() - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, cert: &x509.Certificate{ @@ -226,16 +286,19 @@ func TestACME_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(render.StatusCodedError) - assert.Fatal(t, ok, "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - if assert.NotNil(t, tc.err) { - assert.HasPrefix(t, err.Error(), tc.err.Error()) + err := tc.p.AuthorizeRenew(context.Background(), tc.cert) + if tc.err != nil { + if assert.Implements(t, (*render.StatusCodedError)(nil), err) { + var sc render.StatusCodedError + if errors.As(err, &sc) { + assert.Equal(t, tc.code, sc.StatusCode()) + } } - } else { - assert.Nil(t, tc.err) + assert.EqualError(t, err, tc.err.Error()) + return } + + assert.NoError(t, err) }) } } @@ -250,7 +313,7 @@ func TestACME_AuthorizeSign(t *testing.T) { tests := map[string]func(*testing.T) test{ "ok": func(t *testing.T) test { p, err := generateACME() - assert.FatalError(t, err) + require.NoError(t, err) return test{ p: p, token: "foo", @@ -260,39 +323,43 @@ func TestACME_AuthorizeSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { - if assert.NotNil(t, tc.err) { - sc, ok := err.(render.StatusCodedError) - assert.Fatal(t, ok, "error does not implement StatusCodedError interface") - assert.Equals(t, sc.StatusCode(), tc.code) - assert.HasPrefix(t, err.Error(), tc.err.Error()) + opts, err := tc.p.AuthorizeSign(context.Background(), tc.token) + if tc.err != nil { + if assert.Implements(t, (*render.StatusCodedError)(nil), err) { + var sc render.StatusCodedError + if errors.As(err, &sc) { + assert.Equal(t, tc.code, sc.StatusCode()) + } } - } else { - if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { - assert.Equals(t, 8, len(opts)) // number of SignOptions returned - for _, o := range opts { - switch v := o.(type) { - case *ACME: - case *provisionerExtensionOption: - assert.Equals(t, v.Type, TypeACME) - assert.Equals(t, v.Name, tc.p.GetName()) - assert.Equals(t, v.CredentialID, "") - assert.Len(t, 0, v.KeyValuePairs) - case *forceCNOption: - assert.Equals(t, v.ForceCN, tc.p.ForceCN) - case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) - case defaultPublicKeyValidator: - case *validityValidator: - assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) - case *x509NamePolicyValidator: - assert.Equals(t, nil, v.policyEngine) - case *WebhookController: - assert.Len(t, 0, v.webhooks) - default: - assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) - } + assert.EqualError(t, err, tc.err.Error()) + return + } + + assert.NoError(t, err) + if assert.NotNil(t, opts) { + assert.Len(t, opts, 8) // number of SignOptions returned + for _, o := range opts { + switch v := o.(type) { + case *ACME: + case *provisionerExtensionOption: + assert.Equal(t, v.Type, TypeACME) + assert.Equal(t, v.Name, tc.p.GetName()) + assert.Equal(t, v.CredentialID, "") + assert.Len(t, v.KeyValuePairs, 0) + case *forceCNOption: + assert.Equal(t, v.ForceCN, tc.p.ForceCN) + case profileDefaultDuration: + assert.Equal(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) + case defaultPublicKeyValidator: + case *validityValidator: + assert.Equal(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equal(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) + case *x509NamePolicyValidator: + assert.Equal(t, nil, v.policyEngine) + case *WebhookController: + assert.Len(t, v.webhooks, 0) + default: + require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -323,10 +390,14 @@ func TestACME_IsChallengeEnabled(t *testing.T) { {"ok dns-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, DNS_01}, true}, {"ok tls-alpn-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01"}}, args{ctx, TLS_ALPN_01}, true}, {"ok device-attest-01 enabled", fields{[]ACMEChallenge{"device-attest-01", "dns-01"}}, args{ctx, DEVICE_ATTEST_01}, true}, + {"ok wire-oidc-01 enabled", fields{[]ACMEChallenge{"wire-oidc-01"}}, args{ctx, WIREOIDC_01}, true}, + {"ok wire-dpop-01 enabled", fields{[]ACMEChallenge{"wire-dpop-01"}}, args{ctx, WIREDPOP_01}, true}, {"fail http-01", fields{[]ACMEChallenge{"dns-01"}}, args{ctx, "http-01"}, false}, {"fail dns-01", fields{[]ACMEChallenge{"http-01", "tls-alpn-01"}}, args{ctx, "dns-01"}, false}, {"fail tls-alpn-01", fields{[]ACMEChallenge{"http-01", "dns-01", "device-attest-01"}}, args{ctx, "tls-alpn-01"}, false}, {"fail device-attest-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "device-attest-01"}, false}, + {"fail wire-oidc-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-oidc-01"}, false}, + {"fail wire-dpop-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-dpop-01"}, false}, {"fail unknown", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01", "device-attest-01"}}, args{ctx, "unknown"}, false}, } for _, tt := range tests { @@ -334,9 +405,8 @@ func TestACME_IsChallengeEnabled(t *testing.T) { p := &ACME{ Challenges: tt.fields.Challenges, } - if got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge); got != tt.want { - t.Errorf("ACME.AuthorizeChallenge() = %v, want %v", got, tt.want) - } + got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge) + assert.Equal(t, tt.want, got) }) } } @@ -370,9 +440,8 @@ func TestACME_IsAttestationFormatEnabled(t *testing.T) { p := &ACME{ AttestationFormats: tt.fields.AttestationFormats, } - if got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format); got != tt.want { - t.Errorf("ACME.IsAttestationFormatEnabled() = %v, want %v", got, tt.want) - } + got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format) + assert.Equal(t, tt.want, got) }) } } diff --git a/authority/provisioner/options.go b/authority/provisioner/options.go index 135327349..ec7780819 100644 --- a/authority/provisioner/options.go +++ b/authority/provisioner/options.go @@ -2,7 +2,6 @@ package provisioner import ( "encoding/json" - "fmt" "strings" "github.com/pkg/errors" @@ -54,7 +53,8 @@ func (o *Options) GetSSHOptions() *SSHOptions { return o.SSH } -// GetWireOptions returns the SSH options. +// GetWireOptions returns the Wire options if available. It +// returns an error if they're not available. func (o *Options) GetWireOptions() (*wire.Options, error) { if o == nil { return nil, errors.New("no options available") @@ -62,9 +62,6 @@ func (o *Options) GetWireOptions() (*wire.Options, error) { if o.Wire == nil { return nil, errors.New("no Wire options available") } - if err := o.Wire.Validate(); err != nil { - return nil, fmt.Errorf("failed validating Wire options: %w", err) - } return o.Wire, nil } diff --git a/authority/provisioner/wire/dpop_options.go b/authority/provisioner/wire/dpop_options.go index 721eab014..010cd5ee8 100644 --- a/authority/provisioner/wire/dpop_options.go +++ b/authority/provisioner/wire/dpop_options.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "crypto" + "errors" "fmt" "text/template" @@ -24,9 +25,12 @@ func (o *DPOPOptions) GetSigningKey() crypto.PublicKey { } func (o *DPOPOptions) EvaluateTarget(deviceID string) (string, error) { + if deviceID == "" { + return "", errors.New("deviceID must not be empty") + } buf := new(bytes.Buffer) if err := o.target.Execute(buf, struct{ DeviceID string }{DeviceID: deviceID}); err != nil { - return "", fmt.Errorf("failed executing dpop template: %w", err) + return "", fmt.Errorf("failed executing DPoP template: %w", err) } return buf.String(), nil } diff --git a/authority/provisioner/wire/dpop_options_test.go b/authority/provisioner/wire/dpop_options_test.go new file mode 100644 index 000000000..68aeb7cf7 --- /dev/null +++ b/authority/provisioner/wire/dpop_options_test.go @@ -0,0 +1,58 @@ +package wire + +import ( + "errors" + "testing" + "text/template" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDPOPOptions_EvaluateTarget(t *testing.T) { + tu := "http://wire.com:15958/clients/{{.DeviceID}}/access-token" + target, err := template.New("DeviceID").Parse(tu) + require.NoError(t, err) + fail := "https:/wire.com:15958/clients/{{.DeviceId}}/access-token" + failTarget, err := template.New("DeviceID").Parse(fail) + require.NoError(t, err) + type fields struct { + target *template.Template + } + type args struct { + deviceID string + } + tests := []struct { + name string + fields fields + args args + want string + expectedErr error + }{ + { + name: "ok", fields: fields{target: target}, args: args{deviceID: "deviceID"}, want: "http://wire.com:15958/clients/deviceID/access-token", + }, + { + name: "fail/empty", fields: fields{target: target}, args: args{deviceID: ""}, expectedErr: errors.New("deviceID must not be empty"), + }, + { + name: "fail/template", fields: fields{target: failTarget}, args: args{deviceID: "bla"}, expectedErr: errors.New(`failed executing DPoP template: template: DeviceID:1:32: executing "DeviceID" at <.DeviceId>: can't evaluate field DeviceId in type struct { DeviceID string }`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &DPOPOptions{ + target: tt.fields.target, + } + got, err := o.EvaluateTarget(tt.args.deviceID) + if tt.expectedErr != nil { + assert.EqualError(t, err, tt.expectedErr.Error()) + assert.Empty(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/authority/provisioner/wire/oidc_options.go b/authority/provisioner/wire/oidc_options.go index 5bbcbc7a0..e9139caa1 100644 --- a/authority/provisioner/wire/oidc_options.go +++ b/authority/provisioner/wire/oidc_options.go @@ -15,12 +15,13 @@ import ( ) type Provider struct { - IssuerURL string `json:"issuer,omitempty"` - AuthURL string `json:"authorization_endpoint,omitempty"` - TokenURL string `json:"token_endpoint,omitempty"` - JWKSURL string `json:"jwks_uri,omitempty"` - UserInfoURL string `json:"userinfo_endpoint,omitempty"` - Algorithms []string `json:"id_token_signing_alg_values_supported,omitempty"` + DiscoveryBaseURL string `json:"discoveryBaseUrl,omitempty"` + IssuerURL string `json:"issuerUrl,omitempty"` + AuthURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + JWKSURL string `json:"jwksUrl,omitempty"` + UserInfoURL string `json:"userInfoUrl,omitempty"` + Algorithms []string `json:"signatureAlgorithms,omitempty"` } type Config struct { @@ -40,19 +41,38 @@ type OIDCOptions struct { Config *Config `json:"config,omitempty"` TransformTemplate string `json:"transform,omitempty"` - oidcProviderConfig *oidc.ProviderConfig target *template.Template transform *template.Template + oidcProviderConfig *oidc.ProviderConfig + provider *oidc.Provider + verifier *oidc.IDTokenVerifier } -func (o *OIDCOptions) GetProvider(ctx context.Context) *oidc.Provider { - if o == nil || o.Provider == nil || o.oidcProviderConfig == nil { - return nil +func (o *OIDCOptions) GetVerifier(ctx context.Context) (*oidc.IDTokenVerifier, error) { + if o.verifier == nil { + switch { + case o.Provider.DiscoveryBaseURL != "": + // creates a new OIDC provider using automatic discovery and the default HTTP client + provider, err := oidc.NewProvider(ctx, o.Provider.DiscoveryBaseURL) + if err != nil { + return nil, fmt.Errorf("failed creating new OIDC provider using discovery: %w", err) + } + o.provider = provider + default: + o.provider = o.oidcProviderConfig.NewProvider(ctx) + } + + if o.provider == nil { + return nil, errors.New("no OIDC provider available") + } + + o.verifier = o.provider.Verifier(o.getConfig()) } - return o.oidcProviderConfig.NewProvider(ctx) + + return o.verifier, nil } -func (o *OIDCOptions) GetConfig() *oidc.Config { +func (o *OIDCOptions) getConfig() *oidc.Config { if o == nil || o.Config == nil { return &oidc.Config{} } @@ -74,13 +94,15 @@ func (o *OIDCOptions) validateAndInitialize() (err error) { if o.Provider == nil { return errors.New("provider not set") } - if o.Provider.IssuerURL == "" { - return errors.New("issuer URL must not be empty") + if o.Provider.IssuerURL == "" && o.Provider.DiscoveryBaseURL == "" { + return errors.New("either OIDC discovery or issuer URL must be set") } - o.oidcProviderConfig, err = toOIDCProviderConfig(o.Provider) - if err != nil { - return fmt.Errorf("failed creationg OIDC provider config: %w", err) + if o.Provider.DiscoveryBaseURL == "" { + o.oidcProviderConfig, err = toOIDCProviderConfig(o.Provider) + if err != nil { + return fmt.Errorf("failed creationg OIDC provider config: %w", err) + } } o.target, err = template.New("DeviceID").Parse(o.Provider.IssuerURL) diff --git a/authority/provisioner/wire/oidc_options_test.go b/authority/provisioner/wire/oidc_options_test.go index 8b3eaa75e..a0bf17e9b 100644 --- a/authority/provisioner/wire/oidc_options_test.go +++ b/authority/provisioner/wire/oidc_options_test.go @@ -1,11 +1,20 @@ package wire import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" "testing" "text/template" + "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" ) func TestOIDCOptions_Transform(t *testing.T) { @@ -119,3 +128,178 @@ func TestOIDCOptions_Transform(t *testing.T) { }) } } + +func TestOIDCOptions_EvaluateTarget(t *testing.T) { + tu := "http://target.example.com/{{.DeviceID}}" + target, err := template.New("DeviceID").Parse(tu) + require.NoError(t, err) + empty := "http://target.example.com" + emptyTarget, err := template.New("DeviceID").Parse(empty) + require.NoError(t, err) + fail := "https:/wire.com:15958/clients/{{.DeviceId}}/access-token" + failTarget, err := template.New("DeviceID").Parse(fail) + require.NoError(t, err) + type fields struct { + target *template.Template + } + type args struct { + deviceID string + } + tests := []struct { + name string + fields fields + args args + want string + expectedErr error + }{ + { + name: "ok", fields: fields{target: target}, args: args{deviceID: "deviceID"}, want: "http://target.example.com/deviceID", + }, + { + name: "ok/empty", fields: fields{target: emptyTarget}, args: args{deviceID: ""}, want: "http://target.example.com", + }, + { + name: "fail/template", fields: fields{target: failTarget}, args: args{deviceID: "bla"}, expectedErr: errors.New(`failed executing OIDC template: template: DeviceID:1:32: executing "DeviceID" at <.DeviceId>: can't evaluate field DeviceId in type struct { DeviceID string }`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &OIDCOptions{ + target: tt.fields.target, + } + got, err := o.EvaluateTarget(tt.args.deviceID) + if tt.expectedErr != nil { + assert.EqualError(t, err, tt.expectedErr.Error()) + assert.Empty(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestOIDCOptions_GetVerifier(t *testing.T) { + signerJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + require.NoError(t, err) + require.NoError(t, err) + srv := mustDiscoveryServer(t, signerJWK.Public()) + defer srv.Close() + type fields struct { + Provider *Provider + Config *Config + TransformTemplate string + } + tests := []struct { + name string + fields fields + ctx context.Context + want *oidc.IDTokenVerifier + wantErr bool + }{ + { + name: "fail/invalid-discovery-url", + fields: fields{ + Provider: &Provider{ + DiscoveryBaseURL: "http://invalid.example.com", + }, + Config: &Config{ + ClientID: "client-id", + }, + TransformTemplate: "http://target.example.com/{{.DeviceID}}", + }, + ctx: context.Background(), + wantErr: true, + }, + { + name: "ok/auto", + fields: fields{ + Provider: &Provider{ + DiscoveryBaseURL: srv.URL, + }, + Config: &Config{ + ClientID: "client-id", + }, + TransformTemplate: "http://target.example.com/{{.DeviceID}}", + }, + ctx: context.Background(), + }, + { + name: "ok/fixed", + fields: fields{ + Provider: &Provider{ + IssuerURL: "http://issuer.example.com", + }, + Config: &Config{ + ClientID: "client-id", + }, + TransformTemplate: "http://target.example.com/{{.DeviceID}}", + }, + ctx: context.Background(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &OIDCOptions{ + Provider: tt.fields.Provider, + Config: tt.fields.Config, + TransformTemplate: tt.fields.TransformTemplate, + } + + err := o.validateAndInitialize() + require.NoError(t, err) + + verifier, err := o.GetVerifier(tt.ctx) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, verifier) + return + } + + assert.NoError(t, err) + assert.NotNil(t, verifier) + if assert.NotNil(t, o.provider) { + assert.NotNil(t, o.provider.Endpoint()) + } + }) + } +} + +func mustDiscoveryServer(t *testing.T, pub jose.JSONWebKey) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + server := httptest.NewServer(mux) + b, err := json.Marshal(struct { + Keys []jose.JSONWebKey `json:"keys,omitempty"` + }{ + Keys: []jose.JSONWebKey{pub}, + }) + require.NoError(t, err) + jwks := string(b) + + wellKnown := fmt.Sprintf(`{ + "issuer": "%[1]s", + "authorization_endpoint": "%[1]s/auth", + "token_endpoint": "%[1]s/token", + "jwks_uri": "%[1]s/keys", + "userinfo_endpoint": "%[1]s/userinfo", + "id_token_signing_alg_values_supported": ["ES256"] + }`, server.URL) + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) { + _, err := io.WriteString(w, wellKnown) + if err != nil { + w.WriteHeader(500) + } + }) + mux.HandleFunc("/keys", func(w http.ResponseWriter, req *http.Request) { + _, err := io.WriteString(w, jwks) + if err != nil { + w.WriteHeader(500) + } + }) + + t.Cleanup(server.Close) + return server +} diff --git a/authority/provisioner/wire/wire_options_test.go b/authority/provisioner/wire/wire_options_test.go index fd0acf020..c9fc844b4 100644 --- a/authority/provisioner/wire/wire_options_test.go +++ b/authority/provisioner/wire/wire_options_test.go @@ -55,7 +55,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= }, DPOP: &DPOPOptions{}, }, - expectedErr: errors.New("failed initializing OIDC options: issuer URL must not be empty"), + expectedErr: errors.New("failed initializing OIDC options: either OIDC discovery or issuer URL must be set"), }, { name: "fail/invalid-issuer-url",