Skip to content

Commit

Permalink
Add support for querying tokens by service name. (#18667)
Browse files Browse the repository at this point in the history
Add support for querying tokens by service name

The consul-k8s endpoints controller has a workflow where it fetches all tokens.
This is not performant for large clusters, where there may be a sizable number
of tokens. This commit attempts to alleviate that problem and introduces a new
way to query by the token's service name.
  • Loading branch information
hashi-derek committed Sep 6, 2023
1 parent 14794cc commit 1251265
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .changelog/18667.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
api: Add support for listing ACL tokens by service name.
```
1 change: 1 addition & 0 deletions agent/acl_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ func (s *HTTPHandlers) ACLTokenList(resp http.ResponseWriter, req *http.Request)
args.Policy = req.URL.Query().Get("policy")
args.Role = req.URL.Query().Get("role")
args.AuthMethod = req.URL.Query().Get("authmethod")
args.ServiceName = req.URL.Query().Get("servicename")
if err := parseACLAuthMethodEnterpriseMeta(req, &args.ACLAuthMethodEnterpriseMeta); err != nil {
return nil, err
}
Expand Down
32 changes: 32 additions & 0 deletions agent/acl_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,38 @@ func TestACL_HTTP(t *testing.T) {
require.Error(t, err)
testutil.RequireErrorContains(t, err, "Only lowercase alphanumeric")
})

t.Run("Create with valid service identity", func(t *testing.T) {
tokenInput := &structs.ACLToken{
Description: "token for service identity sn1",
ServiceIdentities: []*structs.ACLServiceIdentity{
{
ServiceName: "sn1",
},
},
}

req, _ := http.NewRequest("PUT", "/v1/acl/token", jsonBody(tokenInput))
req.Header.Add("X-Consul-Token", "root")
resp := httptest.NewRecorder()
_, err := a.srv.ACLTokenCreate(resp, req)
require.NoError(t, err)
})

t.Run("List by ServiceName", func(t *testing.T) {
req, _ := http.NewRequest("GET", "/v1/acl/tokens?servicename=sn1", nil)
req.Header.Add("X-Consul-Token", "root")
resp := httptest.NewRecorder()
raw, err := a.srv.ACLTokenList(resp, req)
require.NoError(t, err)
tokens, ok := raw.(structs.ACLTokenListStubs)
require.True(t, ok)
require.Len(t, tokens, 1)
token := tokens[0]
require.Equal(t, "token for service identity sn1", token.Description)
require.Len(t, token.ServiceIdentities, 1)
require.Equal(t, "sn1", token.ServiceIdentities[0].ServiceName)
})
})
}

Expand Down
14 changes: 12 additions & 2 deletions agent/consul/acl_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,18 @@ func (a *ACL) TokenList(args *structs.ACLTokenListRequest, reply *structs.ACLTok
}

return a.srv.blockingQuery(&args.QueryOptions, &reply.QueryMeta,
func(ws memdb.WatchSet, state *state.Store) error {
index, tokens, err := state.ACLTokenList(ws, args.IncludeLocal, args.IncludeGlobal, args.Policy, args.Role, args.AuthMethod, methodMeta, &args.EnterpriseMeta)
func(ws memdb.WatchSet, s *state.Store) error {
index, tokens, err := s.ACLTokenListWithParameters(ws, state.ACLTokenListParameters{
Local: args.IncludeLocal,
Global: args.IncludeGlobal,
Policy: args.Policy,
Role: args.Role,
MethodName: args.AuthMethod,
ServiceName: args.ServiceName,
MethodMeta: methodMeta,
EnterpriseMeta: &args.EnterpriseMeta,
})

if err != nil {
return err
}
Expand Down
5 changes: 5 additions & 0 deletions agent/consul/acl_replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,10 @@ func TestACLReplication_Tokens(t *testing.T) {

checkSame := func(t *retry.R) {
// only account for global tokens - local tokens shouldn't be replicated
// nolint:staticcheck
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "", nil, nil)
require.NoError(t, err)
// nolint:staticcheck
_, local, err := s2.fsm.State().ACLTokenList(nil, false, true, "", "", "", nil, nil)
require.NoError(t, err)

Expand Down Expand Up @@ -480,6 +482,7 @@ func TestACLReplication_Tokens(t *testing.T) {
})

// verify dc2 local tokens didn't get blown away
// nolint:staticcheck
_, local, err := s2.fsm.State().ACLTokenList(nil, true, false, "", "", "", nil, nil)
require.NoError(t, err)
require.Len(t, local, 50)
Expand Down Expand Up @@ -818,9 +821,11 @@ func TestACLReplication_AllTypes(t *testing.T) {

checkSameTokens := func(t *retry.R) {
// only account for global tokens - local tokens shouldn't be replicated
// nolint:staticcheck
index, remote, err := s1.fsm.State().ACLTokenList(nil, false, true, "", "", "", nil, nil)
require.NoError(t, err)
// Query for all of them, so that we can prove that no globals snuck in.
// nolint:staticcheck
_, local, err := s2.fsm.State().ACLTokenList(nil, true, true, "", "", "", nil, nil)
require.NoError(t, err)

Expand Down
1 change: 1 addition & 0 deletions agent/consul/acl_replication_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func (r *aclTokenReplicator) FetchRemote(srv *Server, lastRemoteIndex uint64) (i
func (r *aclTokenReplicator) FetchLocal(srv *Server) (int, uint64, error) {
r.local = nil

// nolint:staticcheck
idx, local, err := srv.fsm.State().ACLTokenList(nil, false, true, "", "", "", nil, srv.replicationEnterpriseMeta())
if err != nil {
return 0, 0, err
Expand Down
65 changes: 50 additions & 15 deletions agent/consul/state/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,35 @@ func aclTokenGetTxn(tx ReadTxn, ws memdb.WatchSet, value, index string, entMeta
return nil, nil
}

type ACLTokenListParameters struct {
Local bool
Global bool
Policy string
Role string
ServiceName string
MethodName string
MethodMeta *acl.EnterpriseMeta
EnterpriseMeta *acl.EnterpriseMeta
}

// ACLTokenList return a list of ACL Tokens that match the policy, role, and method.
// This function should be treated as deprecated, and ACLTokenListWithParameters should be preferred.
//
// Deprecated: use ACLTokenListWithParameters
func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role, methodName string, methodMeta, entMeta *acl.EnterpriseMeta) (uint64, structs.ACLTokens, error) {
return s.ACLTokenListWithParameters(ws, ACLTokenListParameters{
Local: local,
Global: global,
Policy: policy,
Role: role,
MethodName: methodName,
MethodMeta: methodMeta,
EnterpriseMeta: entMeta,
})
}

// ACLTokenListWithParameters returns a list of ACL Tokens that match the provided parameters.
func (s *Store) ACLTokenListWithParameters(ws memdb.WatchSet, params ACLTokenListParameters) (uint64, structs.ACLTokens, error) {
tx := s.db.Txn(false)
defer tx.Abort()

Expand All @@ -634,43 +661,51 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role

needLocalityFilter := false

if policy == "" && role == "" && methodName == "" {
if global == local {
iter, err = aclTokenListAll(tx, entMeta)
if params.Policy == "" && params.Role == "" && params.MethodName == "" && params.ServiceName == "" {
if params.Global == params.Local {
iter, err = aclTokenListAll(tx, params.EnterpriseMeta)
} else {
iter, err = aclTokenList(tx, entMeta, local)
iter, err = aclTokenList(tx, params.EnterpriseMeta, params.Local)
}

} else if policy != "" && role == "" && methodName == "" {
iter, err = aclTokenListByPolicy(tx, policy, entMeta)
} else if params.Policy != "" && params.Role == "" && params.MethodName == "" && params.ServiceName == "" {
// Find by policy
iter, err = aclTokenListByPolicy(tx, params.Policy, params.EnterpriseMeta)
needLocalityFilter = true

} else if params.Policy == "" && params.Role != "" && params.MethodName == "" && params.ServiceName == "" {
// Find by role
iter, err = aclTokenListByRole(tx, params.Role, params.EnterpriseMeta)
needLocalityFilter = true

} else if policy == "" && role != "" && methodName == "" {
iter, err = aclTokenListByRole(tx, role, entMeta)
} else if params.Policy == "" && params.Role == "" && params.MethodName != "" && params.ServiceName == "" {
// Find by methodName
iter, err = aclTokenListByAuthMethod(tx, params.MethodName, params.MethodMeta, params.EnterpriseMeta)
needLocalityFilter = true

} else if policy == "" && role == "" && methodName != "" {
iter, err = aclTokenListByAuthMethod(tx, methodName, methodMeta, entMeta)
} else if params.Policy == "" && params.Role == "" && params.MethodName == "" && params.ServiceName != "" {
// Find by the service identity's serviceName
iter, err = aclTokenListByServiceName(tx, params.ServiceName, params.EnterpriseMeta)
needLocalityFilter = true

} else {
return 0, nil, fmt.Errorf("can only filter by one of policy, role, or methodName at a time")
return 0, nil, fmt.Errorf("can only filter by one of policy, role, serviceName, or methodName at a time")
}

if err != nil {
return 0, nil, fmt.Errorf("failed acl token lookup: %v", err)
}

if needLocalityFilter && global != local {
if needLocalityFilter && params.Global != params.Local {
iter = memdb.NewFilterIterator(iter, func(raw interface{}) bool {
token, ok := raw.(*structs.ACLToken)
if !ok {
return true
}

if global && !token.Local {
if params.Global && !token.Local {
return false
} else if local && token.Local {
} else if params.Local && token.Local {
return false
}

Expand All @@ -695,7 +730,7 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role
}

// Get the table index.
idx := aclTokenMaxIndex(tx, nil, entMeta)
idx := aclTokenMaxIndex(tx, nil, params.EnterpriseMeta)
return idx, result, nil
}

Expand Down
4 changes: 4 additions & 0 deletions agent/consul/state/acl_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func aclTokenListByAuthMethod(tx ReadTxn, authMethod string, _, _ *acl.Enterpris
return tx.Get(tableACLTokens, indexAuthMethod, AuthMethodQuery{Value: authMethod})
}

func aclTokenListByServiceName(tx ReadTxn, serviceName string, entMeta *acl.EnterpriseMeta) (memdb.ResultIterator, error) {
return tx.Get(tableACLTokens, indexServiceName, Query{Value: serviceName})
}

func aclTokenDeleteWithToken(tx WriteTxn, token *structs.ACLToken, idx uint64) error {
// remove the token
if err := tx.Delete(tableACLTokens, token); err != nil {
Expand Down
25 changes: 25 additions & 0 deletions agent/consul/state/acl_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (
indexAccessor = "accessor"
indexPolicies = "policies"
indexRoles = "roles"
indexServiceName = "service-name"
indexAuthMethod = "authmethod"
indexLocality = "locality"
indexName = "name"
Expand Down Expand Up @@ -103,6 +104,15 @@ func tokensTableSchema() *memdb.TableSchema {
writeIndex: indexExpiresLocalFromACLToken,
},
},
indexServiceName: {
Name: indexServiceName,
AllowMissing: true,
Unique: false,
Indexer: indexerMulti[Query, *structs.ACLToken]{
readIndex: indexFromQuery,
writeIndexMulti: indexServiceNameFromACLToken,
},
},
},
}
}
Expand Down Expand Up @@ -395,6 +405,21 @@ func indexExpiresFromACLToken(t *structs.ACLToken, local bool) ([]byte, error) {
return b.Bytes(), nil
}

func indexServiceNameFromACLToken(token *structs.ACLToken) ([][]byte, error) {
vals := make([][]byte, 0, len(token.ServiceIdentities))
for _, id := range token.ServiceIdentities {
if id != nil && id.ServiceName != "" {
var b indexBuilder
b.String(strings.ToLower(id.ServiceName))
vals = append(vals, b.Bytes())
}
}
if len(vals) == 0 {
return nil, errMissingValueForIndex
}
return vals, nil
}

func authMethodsTableSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: tableACLAuthMethods,
Expand Down
Loading

0 comments on commit 1251265

Please sign in to comment.