Skip to content

Commit

Permalink
Remove Disabled provisioner add add an Uninitialized state
Browse files Browse the repository at this point in the history
This commit renames the Disabled provisioner to Uninitialized and adds
an state instead of just a boolean. It also adds tests.
  • Loading branch information
maraino committed Jul 11, 2024
1 parent 3908932 commit 343e730
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 16 deletions.
2 changes: 1 addition & 1 deletion authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (a *Authority) ReloadAdminResources(ctx context.Context) error {
for _, p := range provList {
if err := p.Init(provisionerConfig); err != nil {
log.Printf("failed to initialize %s provisioner %q: %v\n", p.GetType(), p.GetName(), err)
p = provisioner.Disabled{
p = provisioner.Uninitialized{
Interface: p, Reason: err,
}
}
Expand Down
9 changes: 9 additions & 0 deletions authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
EnableSSHCA: &enableSSHCA,
},
},
&provisioner.JWK{
Name: "uninitialized",
Type: "JWK",
Key: clijwk,
Claims: &provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: time.Minute},
},
},
}
c := &Config{
Address: "127.0.0.1:443",
Expand Down
2 changes: 1 addition & 1 deletion authority/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// If the provisioner is disabled, send an appropriate message to the client
if _, ok := p.(provisioner.Disabled); ok {
if _, ok := p.(provisioner.Uninitialized); ok {
return nil, nil, errs.New(http.StatusUnauthorized, "provisioner %q is disabled due to an initialization error", p.GetName())
}

Expand Down
19 changes: 19 additions & 0 deletions authority/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"

"github.com/google/uuid"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
Expand Down Expand Up @@ -304,6 +305,24 @@ func TestAuthority_authorizeToken(t *testing.T) {
code: http.StatusUnauthorized,
}
},
"fail/uninitialized": func(t *testing.T) *authorizeTest {
cl := jose.Claims{
Subject: "test.smallstep.com",
Issuer: "uninitialized",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: uuid.NewString(),
}
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
token: raw,
err: errors.New(`provisioner "uninitialized" is disabled due to an initialization error`),
code: http.StatusUnauthorized,
}
},
}

for name, genTestCase := range tests {
Expand Down
14 changes: 7 additions & 7 deletions authority/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ type Interface interface {
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
}

// Disabled represents a disabled provisioner. Disabled provisioners are created
// when the Init methods fails.
type Disabled struct {
// Uninitialized represents a disabled provisioner. Uninitialized provisioners
// are created when the Init methods fails.
type Uninitialized struct {
Interface
Reason error
}

// MarshalJSON returns the JSON encoding of the provisioner with the disabled
// reason.
func (p Disabled) MarshalJSON() ([]byte, error) {
func (p Uninitialized) MarshalJSON() ([]byte, error) {
provisionerJSON, err := json.Marshal(p.Interface)
if err != nil {
return nil, err
}
reasonJSON, err := json.Marshal(struct {
Disabled bool `json:"disabled"`
DisabledReason string `json:"disabledReason"`
}{true, p.Reason.Error()})
State string `json:"state"`
StateReason string `json:"stateReason"`
}{"Uninitialized", p.Reason.Error()})
if err != nil {
return nil, err
}
Expand Down
48 changes: 41 additions & 7 deletions authority/provisioner/provisioner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"net/http"
"testing"

"golang.org/x/crypto/ssh"

"github.com/smallstep/assert"
"github.com/go-jose/go-jose/v3"
"github.com/smallstep/certificates/api/render"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)

func TestType_String(t *testing.T) {
Expand Down Expand Up @@ -149,11 +149,11 @@ func TestDefaultIdentityFunc(t *testing.T) {
identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
assert.Equal(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, identity.Usernames, tc.identity.Usernames)
assert.Equal(t, identity.Usernames, tc.identity.Usernames)
}
}
})
Expand Down Expand Up @@ -243,9 +243,43 @@ func TestUnimplementedMethods(t *testing.T) {
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equal(t, http.StatusUnauthorized, sc.StatusCode())
}
assert.Equal(t, msg, err.Error())
})
}
}

func TestUninitialized_MarshalJSON(t *testing.T) {
p := &JWK{
Name: "bad-provisioner",
Type: "JWK",
Key: &jose.JSONWebKey{
Key: []byte("foo"),
},
}

type fields struct {
Interface Interface
Reason error
}
tests := []struct {
name string
fields fields
want []byte
assertion assert.ErrorAssertionFunc
}{
{"ok", fields{p, errors.New("bad key")}, []byte(`{"type":"JWK","name":"bad-provisioner","key":{"kty":"oct","k":"Zm9v"},"state":"Uninitialized","stateReason":"bad key"}`), assert.NoError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := Uninitialized{
Interface: tt.fields.Interface,
Reason: tt.fields.Reason,
}
assert.Equals(t, err.Error(), msg)
got, err := p.MarshalJSON()
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 343e730

Please sign in to comment.