diff --git a/agent/consul/state/acl.go b/agent/consul/state/acl.go index 8854a877d..a8c2fccaa 100644 --- a/agent/consul/state/acl.go +++ b/agent/consul/state/acl.go @@ -711,13 +711,12 @@ func (s *Store) ACLTokenList(ws memdb.WatchSet, local, global bool, policy, role // all tokens so our checks just ensure that global == local needLocalityFilter := false + if policy == "" && role == "" && methodName == "" { if global == local { iter, err = aclTokenListAll(tx, entMeta) - } else if global { - iter, err = aclTokenListGlobal(tx, entMeta) } else { - iter, err = aclTokenListLocal(tx, entMeta) + iter, err = aclTokenList(tx, entMeta, local) } } else if policy != "" && role == "" && methodName == "" { @@ -1769,3 +1768,25 @@ func aclAuthMethodDeleteTxn(tx WriteTxn, idx uint64, name string, entMeta *struc return aclAuthMethodDeleteWithMethod(tx, method, idx) } + +func aclTokenList(tx ReadTxn, entMeta *structs.EnterpriseMeta, locality bool) (memdb.ResultIterator, error) { + // TODO: accept non-pointer value + if entMeta == nil { + entMeta = structs.DefaultEnterpriseMetaInDefaultPartition() + } + // if the namespace is the wildcard that will also be handled as the local index uses + // the NamespaceMultiIndex instead of the NamespaceIndex + q := BoolQuery{ + Value: locality, + EnterpriseMeta: *entMeta, + } + return tx.Get(tableACLTokens, indexLocality, q) +} + +// intFromBool returns 1 if cond is true, 0 otherwise. +func intFromBool(cond bool) byte { + if cond { + return 1 + } + return 0 +} diff --git a/agent/consul/state/acl_oss.go b/agent/consul/state/acl_oss.go index faa9d0cb5..30315a3e9 100644 --- a/agent/consul/state/acl_oss.go +++ b/agent/consul/state/acl_oss.go @@ -78,14 +78,6 @@ func aclTokenListAll(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterato return tx.Get(tableACLTokens, "id") } -func aclTokenListLocal(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get(tableACLTokens, "local", true) -} - -func aclTokenListGlobal(tx ReadTxn, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { - return tx.Get(tableACLTokens, "local", false) -} - func aclTokenListByPolicy(tx ReadTxn, policy string, _ *structs.EnterpriseMeta) (memdb.ResultIterator, error) { return tx.Get(tableACLTokens, indexPolicies, Query{Value: policy}) } diff --git a/agent/consul/state/acl_schema.go b/agent/consul/state/acl_schema.go index fae87af88..6c7a7ca0f 100644 --- a/agent/consul/state/acl_schema.go +++ b/agent/consul/state/acl_schema.go @@ -20,7 +20,7 @@ const ( indexPolicies = "policies" indexRoles = "roles" indexAuthMethod = "authmethod" - indexLocal = "local" + indexLocality = "locality" indexName = "name" ) @@ -75,17 +75,13 @@ func tokensTableSchema() *memdb.TableSchema { writeIndex: writeIndex(indexAuthMethodFromACLToken), }, }, - indexLocal: { - Name: indexLocal, + indexLocality: { + Name: indexLocality, AllowMissing: false, Unique: false, - Indexer: &memdb.ConditionalIndex{ - Conditional: func(obj interface{}) (bool, error) { - if token, ok := obj.(*structs.ACLToken); ok { - return token.Local, nil - } - return false, nil - }, + Indexer: indexerSingle{ + readIndex: readIndex(indexFromBoolQuery), + writeIndex: writeIndex(indexLocalFromACLToken), }, }, "expires-global": { @@ -406,3 +402,24 @@ func indexRolesFromACLToken(raw interface{}) ([][]byte, error) { return vals, nil } + +func indexFromBoolQuery(raw interface{}) ([]byte, error) { + q, ok := raw.(BoolQuery) + if !ok { + return nil, fmt.Errorf("unexpected type %T for BoolQuery index", raw) + } + var b indexBuilder + b.Bool(q.Value) + return b.Bytes(), nil +} + +func indexLocalFromACLToken(raw interface{}) ([]byte, error) { + p, ok := raw.(*structs.ACLToken) + if !ok { + return nil, fmt.Errorf("unexpected type %T for structs.ACLPolicy index", raw) + } + + var b indexBuilder + b.Bool(p.Local) + return b.Bytes(), nil +} diff --git a/agent/consul/state/indexer.go b/agent/consul/state/indexer.go index 044306d8e..9dca91b03 100644 --- a/agent/consul/state/indexer.go +++ b/agent/consul/state/indexer.go @@ -133,3 +133,7 @@ func (b *indexBuilder) Raw(v []byte) { func (b *indexBuilder) Bytes() []byte { return (*bytes.Buffer)(b).Bytes() } + +func (b *indexBuilder) Bool(v bool) { + b.Raw([]byte{intFromBool(v)}) +} diff --git a/agent/consul/state/query.go b/agent/consul/state/query.go index d264e98aa..799a8f019 100644 --- a/agent/consul/state/query.go +++ b/agent/consul/state/query.go @@ -96,6 +96,18 @@ type BoolQuery struct { structs.EnterpriseMeta } +// NamespaceOrDefault exists because structs.EnterpriseMeta uses a pointer +// receiver for this method. Remove once that is fixed. +func (q BoolQuery) NamespaceOrDefault() string { + return q.EnterpriseMeta.NamespaceOrDefault() +} + +// PartitionOrDefault exists because structs.EnterpriseMeta uses a pointer +// receiver for this method. Remove once that is fixed. +func (q BoolQuery) PartitionOrDefault() string { + return q.EnterpriseMeta.PartitionOrDefault() +} + // KeyValueQuery is a type used to query for both a key and a value that may // include an enterprise identifier. type KeyValueQuery struct {