EventPublisher: Make Unsubscribe a function on Subscription

It is critical that Unsubscribe be called with the same pointer to a
SubscriptionRequest that was used to create the Subscription. The
docstring made that clear, but it sill allowed a caler to get it wrong by
creating a new SubscriptionRequest.

By hiding this detail from the caller, and only exposing an Unsubscribe
method, it should be impossible to fail to Unsubscribe.

Also update some godoc strings.
This commit is contained in:
Daniel Nephin 2020-06-18 18:29:06 -04:00
parent 86976cf23c
commit 02bc5a26e4
3 changed files with 39 additions and 31 deletions

View file

@ -51,12 +51,16 @@ type EventPublisher struct {
} }
type subscriptions struct { type subscriptions struct {
// lock for byToken. If both subscription.lock and EventPublisher.lock need
// to be held, EventPublisher.lock MUST always be acquired first.
lock sync.RWMutex lock sync.RWMutex
// subsByToken stores a list of Subscription objects outstanding indexed by a // byToken is an mapping of active Subscriptions indexed by a the token and
// hash of the ACL token they used to subscribe so we can reload them if their // a pointer to the request.
// ACL permissions change. // When the token is modified all subscriptions under that token will be
subsByToken map[string]map[*stream.SubscribeRequest]*stream.Subscription // reloaded.
// A subscription may be unsubscribed by using the pointer to the request.
byToken map[string]map[*stream.SubscribeRequest]*stream.Subscription
} }
type commitUpdate struct { type commitUpdate struct {
@ -70,7 +74,7 @@ func NewEventPublisher(handlers map[stream.Topic]topicHandler, snapCacheTTL time
snapCache: make(map[stream.Topic]map[string]*stream.EventSnapshot), snapCache: make(map[stream.Topic]map[string]*stream.EventSnapshot),
publishCh: make(chan commitUpdate, 64), publishCh: make(chan commitUpdate, 64),
subscriptions: &subscriptions{ subscriptions: &subscriptions{
subsByToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription), byToken: make(map[string]map[*stream.SubscribeRequest]*stream.Subscription),
}, },
handlers: handlers, handlers: handlers,
} }
@ -160,8 +164,8 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error {
switch event.Topic { switch event.Topic {
case stream.Topic_ACLTokens: case stream.Topic_ACLTokens:
token := event.Payload.(*structs.ACLToken) token := event.Payload.(*structs.ACLToken)
for _, sub := range s.subsByToken[token.SecretID] { for _, sub := range s.byToken[token.SecretID] {
sub.CloseReload() sub.ForceReload()
} }
case stream.Topic_ACLPolicies: case stream.Topic_ACLPolicies:
@ -199,13 +203,13 @@ func (s *subscriptions) handleACLUpdate(tx ReadTxn, event stream.Event) error {
return nil return nil
} }
// This method requires the EventPublisher.lock is held // This method requires the subscriptions.lock.RLock is held (the read-only lock)
func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) { func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator) {
for token := tokens.Next(); token != nil; token = tokens.Next() { for token := tokens.Next(); token != nil; token = tokens.Next() {
token := token.(*structs.ACLToken) token := token.(*structs.ACLToken)
if subs, ok := s.subsByToken[token.SecretID]; ok { if subs, ok := s.byToken[token.SecretID]; ok {
for _, sub := range subs { for _, sub := range subs {
sub.CloseReload() sub.ForceReload()
} }
} }
} }
@ -218,8 +222,8 @@ func (s *subscriptions) closeSubscriptionsForTokens(tokens memdb.ResultIterator)
// decides it can no longer maintain correct operation for example if ACL // decides it can no longer maintain correct operation for example if ACL
// policies changed or the state store was restored. // policies changed or the state store was restored.
// //
// When the called is finished with the subscription for any reason, it must // When the caller is finished with the subscription for any reason, it must
// call Unsubscribe to free ACL tracking resources. // call Subscription.Unsubscribe to free ACL tracking resources.
func (e *EventPublisher) Subscribe( func (e *EventPublisher) Subscribe(
ctx context.Context, ctx context.Context,
req *stream.SubscribeRequest, req *stream.SubscribeRequest,
@ -278,7 +282,12 @@ func (e *EventPublisher) Subscribe(
} }
e.subscriptions.add(req, sub) e.subscriptions.add(req, sub)
// Set unsubscribe so that the caller doesn't need to keep track of the
// SubscriptionRequest, and can not accidentally call unsubscribe with the
// wrong value.
sub.Unsubscribe = func() {
e.subscriptions.unsubscribe(req)
}
return sub, nil return sub, nil
} }
@ -286,28 +295,29 @@ func (s *subscriptions) add(req *stream.SubscribeRequest, sub *stream.Subscripti
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
subsByToken, ok := s.subsByToken[req.Token] subsByToken, ok := s.byToken[req.Token]
if !ok { if !ok {
subsByToken = make(map[*stream.SubscribeRequest]*stream.Subscription) subsByToken = make(map[*stream.SubscribeRequest]*stream.Subscription)
s.subsByToken[req.Token] = subsByToken s.byToken[req.Token] = subsByToken
} }
subsByToken[req] = sub subsByToken[req] = sub
} }
// Unsubscribe must be called when a client is no longer interested in a // unsubscribe must be called when a client is no longer interested in a
// subscription to free resources monitoring changes in it's ACL token. The same // subscription to free resources monitoring changes in it's ACL token.
// request object passed to Subscribe must be used. //
func (s *subscriptions) Unsubscribe(req *stream.SubscribeRequest) { // req MUST be the same pointer that was used to register the subscription.
func (s *subscriptions) unsubscribe(req *stream.SubscribeRequest) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
subsByToken, ok := s.subsByToken[req.Token] subsByToken, ok := s.byToken[req.Token]
if !ok { if !ok {
return return
} }
delete(subsByToken, req) delete(subsByToken, req)
if len(subsByToken) == 0 { if len(subsByToken) == 0 {
delete(s.subsByToken, req.Token) delete(s.byToken, req.Token)
} }
} }

View file

@ -42,6 +42,10 @@ type Subscription struct {
// cancelFn stores the context cancel function that will wake up the // cancelFn stores the context cancel function that will wake up the
// in-progress Next call on a server-initiated state change e.g. Reload. // in-progress Next call on a server-initiated state change e.g. Reload.
cancelFn func() cancelFn func()
// Unsubscribe is a function set by EventPublisher that is called to
// free resources when the subscription is no longer needed.
Unsubscribe func()
} }
type SubscribeRequest struct { type SubscribeRequest struct {
@ -116,9 +120,9 @@ func (s *Subscription) Next() ([]Event, error) {
} }
} }
// CloseReload closes the stream and signals that the subscriber should reload. // ForceReload closes the stream and signals that the subscriber should reload.
// It is safe to call from any goroutine. // It is safe to call from any goroutine.
func (s *Subscription) CloseReload() { func (s *Subscription) ForceReload() {
swapped := atomic.CompareAndSwapUint32(&s.state, SubscriptionStateOpen, swapped := atomic.CompareAndSwapUint32(&s.state, SubscriptionStateOpen,
SubscriptionStateCloseReload) SubscriptionStateCloseReload)
@ -126,9 +130,3 @@ func (s *Subscription) CloseReload() {
s.cancelFn() s.cancelFn()
} }
} }
// Request returns the request object that started the subscription.
// TODO: remove
func (s *Subscription) Request() *SubscribeRequest {
return s.req
}

View file

@ -118,11 +118,11 @@ func TestSubscriptionCloseReload(t *testing.T) {
require.Len(t, got, 1) require.Len(t, got, 1)
require.Equal(t, index, got[0].Index) require.Equal(t, index, got[0].Index)
// Schedule a CloseReload simulating the server deciding this subscroption // Schedule a ForceReload simulating the server deciding this subscroption
// needs to reset (e.g. on ACL perm change). // needs to reset (e.g. on ACL perm change).
start = time.Now() start = time.Now()
time.AfterFunc(200*time.Millisecond, func() { time.AfterFunc(200*time.Millisecond, func() {
sub.CloseReload() sub.ForceReload()
}) })
_, err = sub.Next() _, err = sub.Next()