package spanner import ( "context" "fmt" "sync" "time" "cloud.google.com/go/spanner" metrics "github.com/armon/go-metrics" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/physical" "github.com/pkg/errors" "google.golang.org/grpc/codes" ) // Verify Backend satisfies the correct interfaces var ( _ physical.HABackend = (*Backend)(nil) _ physical.Lock = (*Lock)(nil) ) const ( // LockRenewInterval is the time to wait between lock renewals. LockRenewInterval = 5 * time.Second // LockRetryInterval is the amount of time to wait if the lock fails before // trying again. LockRetryInterval = 5 * time.Second // LockTTL is the default lock TTL. LockTTL = 15 * time.Second // LockWatchRetryInterval is the amount of time to wait if a watch fails // before trying again. LockWatchRetryInterval = 5 * time.Second // LockWatchRetryMax is the number of times to retry a failed watch before // signaling that leadership is lost. LockWatchRetryMax = 5 ) var ( // metricLockUnlock is the metric to register for a lock delete. metricLockUnlock = []string{"spanner", "lock", "unlock"} // metricLockGet is the metric to register for a lock get. metricLockLock = []string{"spanner", "lock", "lock"} // metricLockValue is the metric to register for a lock create/update. metricLockValue = []string{"spanner", "lock", "value"} ) // Lock is the HA lock. type Lock struct { // backend is the underlying physical backend. backend *Backend // key is the name of the key. value is the value of the key. key, value string // held is a boolean indicating if the lock is currently held. held bool // identity is the internal identity of this key (unique to this server // instance). identity string // lock is an internal lock lock sync.Mutex // stopCh is the channel that stops all operations. It may be closed in the // event of a leader loss or graceful shutdown. stopped is a boolean // indicating if we are stopped - it exists to prevent double closing the // channel. stopLock is a mutex around the locks. stopCh chan struct{} stopped bool stopLock sync.Mutex // Allow modifying the Lock durations for ease of unit testing. renewInterval time.Duration retryInterval time.Duration ttl time.Duration watchRetryInterval time.Duration watchRetryMax int } // LockRecord is the struct that corresponds to a lock. type LockRecord struct { Key string Value string Identity string Timestamp time.Time } // HAEnabled implements HABackend and indicates that this backend supports high // availability. func (b *Backend) HAEnabled() bool { return b.haEnabled } // LockWith acquires a mutual exclusion based on the given key. func (b *Backend) LockWith(key, value string) (physical.Lock, error) { identity, err := uuid.GenerateUUID() if err != nil { return nil, fmt.Errorf("lock with: %w", err) } return &Lock{ backend: b, key: key, value: value, identity: identity, stopped: true, renewInterval: LockRenewInterval, retryInterval: LockRetryInterval, ttl: LockTTL, watchRetryInterval: LockWatchRetryInterval, watchRetryMax: LockWatchRetryMax, }, nil } // Lock acquires the given lock. The stopCh is optional. If closed, it // interrupts the lock acquisition attempt. The returned channel should be // closed when leadership is lost. func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { defer metrics.MeasureSince(metricLockLock, time.Now()) l.lock.Lock() defer l.lock.Unlock() if l.held { return nil, errors.New("lock already held") } // Attempt to lock - this function blocks until a lock is acquired or an error // occurs. acquired, err := l.attemptLock(stopCh) if err != nil { return nil, fmt.Errorf("lock: %w", err) } if !acquired { return nil, nil } // We have the lock now l.held = true // Build the locks l.stopLock.Lock() l.stopCh = make(chan struct{}) l.stopped = false l.stopLock.Unlock() // Periodically renew and watch the lock go l.renewLock() go l.watchLock() return l.stopCh, nil } // Unlock releases the lock. func (l *Lock) Unlock() error { defer metrics.MeasureSince(metricLockUnlock, time.Now()) l.lock.Lock() defer l.lock.Unlock() if !l.held { return nil } // Stop any existing locking or renewal attempts l.stopLock.Lock() if !l.stopped { l.stopped = true close(l.stopCh) } l.stopLock.Unlock() // Delete ctx := context.Background() if _, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Identity"}) if err != nil { if spanner.ErrCode(err) != codes.NotFound { return nil } return err } var r LockRecord if derr := row.ToStruct(&r); derr != nil { return fmt.Errorf("failed to decode to struct: %w", derr) } // If the identity is different, that means that between the time that after // we stopped acquisition, the TTL expired and someone else grabbed the // lock. We do not want to delete a lock that is not our own. if r.Identity != l.identity { return nil } return txn.BufferWrite([]*spanner.Mutation{ spanner.Delete(l.backend.haTable, spanner.Key{l.key}), }) }); err != nil { return fmt.Errorf("unlock: %w", err) } // We are no longer holding the lock l.held = false return nil } // Value returns the value of the lock and if it is held. func (l *Lock) Value() (bool, string, error) { defer metrics.MeasureSince(metricLockValue, time.Now()) r, err := l.get(context.Background()) if err != nil { return false, "", err } if r == nil { return false, "", err } return true, string(r.Value), nil } // attemptLock attempts to acquire a lock. If the given channel is closed, the // acquisition attempt stops. This function returns when a lock is acquired or // an error occurs. func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) { ticker := time.NewTicker(l.retryInterval) defer ticker.Stop() for { select { case <-ticker.C: acquired, err := l.writeLock() if err != nil { return false, fmt.Errorf("attempt lock: %w", err) } if !acquired { continue } return true, nil case <-stopCh: return false, nil } } } // renewLock renews the given lock until the channel is closed. func (l *Lock) renewLock() { ticker := time.NewTicker(l.renewInterval) defer ticker.Stop() for { select { case <-ticker.C: l.writeLock() case <-l.stopCh: return } } } // watchLock checks whether the lock has changed in the table and closes the // leader channel accordingly. If an error occurs during the check, watchLock // will retry the operation and then close the leader channel if it can't // succeed after retries. func (l *Lock) watchLock() { retries := 0 ticker := time.NewTicker(l.watchRetryInterval) OUTER: for { // Check if the channel is already closed select { case <-l.stopCh: break OUTER default: } // Check if we've exceeded retries if retries >= l.watchRetryMax-1 { break OUTER } // Wait for the timer select { case <-ticker.C: case <-l.stopCh: break OUTER } // Attempt to read the key r, err := l.get(context.Background()) if err != nil { retries++ continue } // Verify the identity is the same if r == nil || r.Identity != l.identity { break OUTER } } l.stopLock.Lock() defer l.stopLock.Unlock() if !l.stopped { l.stopped = true close(l.stopCh) } } // writeLock writes the given lock using the following algorithm: // // - lock does not exist // - write the lock // // - lock exists // - if key is empty or identity is the same or timestamp exceeds TTL // - update the lock to self func (l *Lock) writeLock() (bool, error) { // Keep track of whether the lock was written lockWritten := false // Create a transaction to read and the update (maybe) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // The transaction will be retried, and it could sit in a queue behind, say, // the delete operation. To stop the transaction, we close the context when // the associated stopCh is received. go func() { select { case <-l.stopCh: cancel() case <-ctx.Done(): } }() _, err := l.backend.haClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { row, err := txn.ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Key", "Identity", "Timestamp"}) if err != nil && spanner.ErrCode(err) != codes.NotFound { return err } // If there was a record, verify that the record is still trustable. if row != nil { var r LockRecord if derr := row.ToStruct(&r); derr != nil { return fmt.Errorf("failed to decode to struct: %w", derr) } // If the key is empty or the identity is ours or the ttl expired, we can // write. Otherwise, return now because we cannot. if r.Key != "" && r.Identity != l.identity && time.Now().UTC().Sub(r.Timestamp) < l.ttl { return nil } } m, err := spanner.InsertOrUpdateStruct(l.backend.haTable, &LockRecord{ Key: l.key, Value: l.value, Identity: l.identity, Timestamp: time.Now().UTC(), }) if err != nil { return fmt.Errorf("failed to generate struct: %w", err) } if err := txn.BufferWrite([]*spanner.Mutation{m}); err != nil { return fmt.Errorf("failed to write: %w", err) } // Mark that the lock was acquired lockWritten = true return nil }) if err != nil { return false, fmt.Errorf("write lock: %w", err) } return lockWritten, nil } // get retrieves the value for the lock. func (l *Lock) get(ctx context.Context) (*LockRecord, error) { // Read row, err := l.backend.haClient.Single().ReadRow(ctx, l.backend.haTable, spanner.Key{l.key}, []string{"Key", "Value", "Timestamp", "Identity"}) if spanner.ErrCode(err) == codes.NotFound { return nil, nil } if err != nil { return nil, fmt.Errorf("failed to read value for %q: %w", l.key, err) } var r LockRecord if err := row.ToStruct(&r); err != nil { return nil, fmt.Errorf("failed to decode lock: %w", err) } return &r, nil }