// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package api import ( "encoding/json" "fmt" "path" "sync" "time" ) const ( // DefaultSemaphoreSessionName is the Session Name we assign if none is provided DefaultSemaphoreSessionName = "Consul API Semaphore" // DefaultSemaphoreSessionTTL is the default session TTL if no Session is provided // when creating a new Semaphore. This is used because we do not have another // other check to depend upon. DefaultSemaphoreSessionTTL = "15s" // DefaultSemaphoreWaitTime is how long we block for at a time to check if semaphore // acquisition is possible. This affects the minimum time it takes to cancel // a Semaphore acquisition. DefaultSemaphoreWaitTime = 15 * time.Second // DefaultSemaphoreKey is the key used within the prefix to // use for coordination between all the contenders. DefaultSemaphoreKey = ".lock" // SemaphoreFlagValue is a magic flag we set to indicate a key // is being used for a semaphore. It is used to detect a potential // conflict with a lock. SemaphoreFlagValue = 0xe0f69a2baa414de0 ) var ( // ErrSemaphoreHeld is returned if we attempt to double lock ErrSemaphoreHeld = fmt.Errorf("Semaphore already held") // ErrSemaphoreNotHeld is returned if we attempt to unlock a semaphore // that we do not hold. ErrSemaphoreNotHeld = fmt.Errorf("Semaphore not held") // ErrSemaphoreInUse is returned if we attempt to destroy a semaphore // that is in use. ErrSemaphoreInUse = fmt.Errorf("Semaphore in use") // ErrSemaphoreConflict is returned if the flags on a key // used for a semaphore do not match expectation ErrSemaphoreConflict = fmt.Errorf("Existing key does not match semaphore use") ) // Semaphore is used to implement a distributed semaphore // using the Consul KV primitives. type Semaphore struct { c *Client opts *SemaphoreOptions isHeld bool sessionRenew chan struct{} lockSession string l sync.Mutex } // SemaphoreOptions is used to parameterize the Semaphore type SemaphoreOptions struct { Prefix string // Must be set and have write permissions Limit int // Must be set, and be positive Value []byte // Optional, value to associate with the contender entry Session string // Optional, created if not specified SessionName string // Optional, defaults to DefaultLockSessionName SessionTTL string // Optional, defaults to DefaultLockSessionTTL MonitorRetries int // Optional, defaults to 0 which means no retries MonitorRetryTime time.Duration // Optional, defaults to DefaultMonitorRetryTime SemaphoreWaitTime time.Duration // Optional, defaults to DefaultSemaphoreWaitTime SemaphoreTryOnce bool // Optional, defaults to false which means try forever Namespace string `json:",omitempty"` // Optional, defaults to API client config, namespace of ACL token, or "default" namespace } // semaphoreLock is written under the DefaultSemaphoreKey and // is used to coordinate between all the contenders. type semaphoreLock struct { // Limit is the integer limit of holders. This is used to // verify that all the holders agree on the value. Limit int // Holders is a list of all the semaphore holders. // It maps the session ID to true. It is used as a set effectively. Holders map[string]bool } // SemaphorePrefix is used to created a Semaphore which will operate // at the given KV prefix and uses the given limit for the semaphore. // The prefix must have write privileges, and the limit must be agreed // upon by all contenders. func (c *Client) SemaphorePrefix(prefix string, limit int) (*Semaphore, error) { opts := &SemaphoreOptions{ Prefix: prefix, Limit: limit, } return c.SemaphoreOpts(opts) } // SemaphoreOpts is used to create a Semaphore with the given options. // The prefix must have write privileges, and the limit must be agreed // upon by all contenders. If a Session is not provided, one will be created. func (c *Client) SemaphoreOpts(opts *SemaphoreOptions) (*Semaphore, error) { if opts.Prefix == "" { return nil, fmt.Errorf("missing prefix") } if opts.Limit <= 0 { return nil, fmt.Errorf("semaphore limit must be positive") } if opts.SessionName == "" { opts.SessionName = DefaultSemaphoreSessionName } if opts.SessionTTL == "" { opts.SessionTTL = DefaultSemaphoreSessionTTL } else { if _, err := time.ParseDuration(opts.SessionTTL); err != nil { return nil, fmt.Errorf("invalid SessionTTL: %v", err) } } if opts.MonitorRetryTime == 0 { opts.MonitorRetryTime = DefaultMonitorRetryTime } if opts.SemaphoreWaitTime == 0 { opts.SemaphoreWaitTime = DefaultSemaphoreWaitTime } s := &Semaphore{ c: c, opts: opts, } return s, nil } // Acquire attempts to reserve a slot in the semaphore, blocking until // success, interrupted via the stopCh or an error is encountered. // Providing a non-nil stopCh can be used to abort the attempt. // On success, a channel is returned that represents our slot. // This channel could be closed at any time due to session invalidation, // communication errors, operator intervention, etc. It is NOT safe to // assume that the slot is held until Release() unless the Session is specifically // created without any associated health checks. By default Consul sessions // prefer liveness over safety and an application must be able to handle // the session being lost. func (s *Semaphore) Acquire(stopCh <-chan struct{}) (<-chan struct{}, error) { // Hold the lock as we try to acquire s.l.Lock() defer s.l.Unlock() // Check if we already hold the semaphore if s.isHeld { return nil, ErrSemaphoreHeld } // Check if we need to create a session first s.lockSession = s.opts.Session if s.lockSession == "" { sess, err := s.createSession() if err != nil { return nil, fmt.Errorf("failed to create session: %v", err) } s.sessionRenew = make(chan struct{}) s.lockSession = sess session := s.c.Session() go session.RenewPeriodic(s.opts.SessionTTL, sess, nil, s.sessionRenew) // If we fail to acquire the lock, cleanup the session defer func() { if !s.isHeld { close(s.sessionRenew) s.sessionRenew = nil } }() } // Create the contender entry kv := s.c.KV() wOpts := WriteOptions{Namespace: s.opts.Namespace} made, _, err := kv.Acquire(s.contenderEntry(s.lockSession), &wOpts) if err != nil || !made { return nil, fmt.Errorf("failed to make contender entry: %v", err) } // Setup the query options qOpts := QueryOptions{ WaitTime: s.opts.SemaphoreWaitTime, Namespace: s.opts.Namespace, } start := time.Now() attempts := 0 WAIT: // Check if we should quit select { case <-stopCh: return nil, nil default: } // Handle the one-shot mode. if s.opts.SemaphoreTryOnce && attempts > 0 { elapsed := time.Since(start) if elapsed > s.opts.SemaphoreWaitTime { return nil, nil } // Query wait time should not exceed the semaphore wait time qOpts.WaitTime = s.opts.SemaphoreWaitTime - elapsed } attempts++ // Read the prefix pairs, meta, err := kv.List(s.opts.Prefix, &qOpts) if err != nil { return nil, fmt.Errorf("failed to read prefix: %v", err) } // Decode the lock lockPair := s.findLock(pairs) if lockPair.Flags != SemaphoreFlagValue { return nil, ErrSemaphoreConflict } lock, err := s.decodeLock(lockPair) if err != nil { return nil, err } // Verify we agree with the limit if lock.Limit != s.opts.Limit { return nil, fmt.Errorf("semaphore limit conflict (lock: %d, local: %d)", lock.Limit, s.opts.Limit) } // Prune the dead holders s.pruneDeadHolders(lock, pairs) // Check if the lock is held if len(lock.Holders) >= lock.Limit { qOpts.WaitIndex = meta.LastIndex goto WAIT } // Create a new lock with us as a holder lock.Holders[s.lockSession] = true newLock, err := s.encodeLock(lock, lockPair.ModifyIndex) if err != nil { return nil, err } // Attempt the acquisition didSet, _, err := kv.CAS(newLock, &wOpts) if err != nil { return nil, fmt.Errorf("failed to update lock: %v", err) } if !didSet { // Update failed, could have been a race with another contender, // retry the operation goto WAIT } // Watch to ensure we maintain ownership of the slot lockCh := make(chan struct{}) go s.monitorLock(s.lockSession, lockCh) // Set that we own the lock s.isHeld = true // Acquired! All done return lockCh, nil } // Release is used to voluntarily give up our semaphore slot. It is // an error to call this if the semaphore has not been acquired. func (s *Semaphore) Release() error { // Hold the lock as we try to release s.l.Lock() defer s.l.Unlock() // Ensure the lock is actually held if !s.isHeld { return ErrSemaphoreNotHeld } // Set that we no longer own the lock s.isHeld = false // Stop the session renew if s.sessionRenew != nil { defer func() { close(s.sessionRenew) s.sessionRenew = nil }() } // Get and clear the lock session lockSession := s.lockSession s.lockSession = "" // Remove ourselves as a lock holder kv := s.c.KV() key := path.Join(s.opts.Prefix, DefaultSemaphoreKey) wOpts := WriteOptions{Namespace: s.opts.Namespace} qOpts := QueryOptions{Namespace: s.opts.Namespace} READ: pair, _, err := kv.Get(key, &qOpts) if err != nil { return err } if pair == nil { pair = &KVPair{} } lock, err := s.decodeLock(pair) if err != nil { return err } // Create a new lock without us as a holder if _, ok := lock.Holders[lockSession]; ok { delete(lock.Holders, lockSession) newLock, err := s.encodeLock(lock, pair.ModifyIndex) if err != nil { return err } // Swap the locks didSet, _, err := kv.CAS(newLock, &wOpts) if err != nil { return fmt.Errorf("failed to update lock: %v", err) } if !didSet { goto READ } } // Destroy the contender entry contenderKey := path.Join(s.opts.Prefix, lockSession) if _, err := kv.Delete(contenderKey, &wOpts); err != nil { return err } return nil } // Destroy is used to cleanup the semaphore entry. It is not necessary // to invoke. It will fail if the semaphore is in use. func (s *Semaphore) Destroy() error { // Hold the lock as we try to acquire s.l.Lock() defer s.l.Unlock() // Check if we already hold the semaphore if s.isHeld { return ErrSemaphoreHeld } // List for the semaphore kv := s.c.KV() q := QueryOptions{Namespace: s.opts.Namespace} pairs, _, err := kv.List(s.opts.Prefix, &q) if err != nil { return fmt.Errorf("failed to read prefix: %v", err) } // Find the lock pair, bail if it doesn't exist lockPair := s.findLock(pairs) if lockPair.ModifyIndex == 0 { return nil } if lockPair.Flags != SemaphoreFlagValue { return ErrSemaphoreConflict } // Decode the lock lock, err := s.decodeLock(lockPair) if err != nil { return err } // Prune the dead holders s.pruneDeadHolders(lock, pairs) // Check if there are any holders if len(lock.Holders) > 0 { return ErrSemaphoreInUse } // Attempt the delete w := WriteOptions{Namespace: s.opts.Namespace} didRemove, _, err := kv.DeleteCAS(lockPair, &w) if err != nil { return fmt.Errorf("failed to remove semaphore: %v", err) } if !didRemove { return ErrSemaphoreInUse } return nil } // createSession is used to create a new managed session func (s *Semaphore) createSession() (string, error) { session := s.c.Session() se := &SessionEntry{ Name: s.opts.SessionName, TTL: s.opts.SessionTTL, Behavior: SessionBehaviorDelete, } w := WriteOptions{Namespace: s.opts.Namespace} id, _, err := session.Create(se, &w) if err != nil { return "", err } return id, nil } // contenderEntry returns a formatted KVPair for the contender func (s *Semaphore) contenderEntry(session string) *KVPair { return &KVPair{ Key: path.Join(s.opts.Prefix, session), Value: s.opts.Value, Session: session, Flags: SemaphoreFlagValue, } } // findLock is used to find the KV Pair which is used for coordination func (s *Semaphore) findLock(pairs KVPairs) *KVPair { key := path.Join(s.opts.Prefix, DefaultSemaphoreKey) for _, pair := range pairs { if pair.Key == key { return pair } } return &KVPair{Flags: SemaphoreFlagValue} } // decodeLock is used to decode a semaphoreLock from an // entry in Consul func (s *Semaphore) decodeLock(pair *KVPair) (*semaphoreLock, error) { // Handle if there is no lock if pair == nil || pair.Value == nil { return &semaphoreLock{ Limit: s.opts.Limit, Holders: make(map[string]bool), }, nil } l := &semaphoreLock{} if err := json.Unmarshal(pair.Value, l); err != nil { return nil, fmt.Errorf("lock decoding failed: %v", err) } return l, nil } // encodeLock is used to encode a semaphoreLock into a KVPair // that can be PUT func (s *Semaphore) encodeLock(l *semaphoreLock, oldIndex uint64) (*KVPair, error) { enc, err := json.Marshal(l) if err != nil { return nil, fmt.Errorf("lock encoding failed: %v", err) } pair := &KVPair{ Key: path.Join(s.opts.Prefix, DefaultSemaphoreKey), Value: enc, Flags: SemaphoreFlagValue, ModifyIndex: oldIndex, } return pair, nil } // pruneDeadHolders is used to remove all the dead lock holders func (s *Semaphore) pruneDeadHolders(lock *semaphoreLock, pairs KVPairs) { // Gather all the live holders alive := make(map[string]struct{}, len(pairs)) for _, pair := range pairs { if pair.Session != "" { alive[pair.Session] = struct{}{} } } // Remove any holders that are dead for holder := range lock.Holders { if _, ok := alive[holder]; !ok { delete(lock.Holders, holder) } } } // monitorLock is a long running routine to monitor a semaphore ownership // It closes the stopCh if we lose our slot. func (s *Semaphore) monitorLock(session string, stopCh chan struct{}) { defer close(stopCh) kv := s.c.KV() opts := QueryOptions{ RequireConsistent: true, Namespace: s.opts.Namespace, } WAIT: retries := s.opts.MonitorRetries RETRY: pairs, meta, err := kv.List(s.opts.Prefix, &opts) if err != nil { // If configured we can try to ride out a brief Consul unavailability // by doing retries. Note that we have to attempt the retry in a non- // blocking fashion so that we have a clean place to reset the retry // counter if service is restored. if retries > 0 && IsRetryableError(err) { time.Sleep(s.opts.MonitorRetryTime) retries-- opts.WaitIndex = 0 goto RETRY } return } lockPair := s.findLock(pairs) lock, err := s.decodeLock(lockPair) if err != nil { return } s.pruneDeadHolders(lock, pairs) if _, ok := lock.Holders[session]; ok { opts.WaitIndex = meta.LastIndex goto WAIT } }