// Source: open-vault/physical/oci/oci_ha.go

// Copyright © 2019, Oracle and/or its affiliates.
package oci
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/physical"
"github.com/oracle/oci-go-sdk/objectstorage"
)
// The lock implementation below prioritizes ensuring that there are never two primaries at any given point in time
// over high availability of the primary instance.
// Compile-time assertions: *Backend must implement physical.HABackend and
// *Lock must implement physical.Lock. The blank identifiers mean nothing is
// kept at runtime; a mismatch is reported as a build error.
var (
_ physical.HABackend = (*Backend)(nil)
_ physical.Lock = (*Lock)(nil)
)
// Default timings and limits for the HA lock. LockWith copies the interval,
// TTL and watch values into each new Lock; the remaining constants are read
// directly by writeLock and get.
const (
// LockRenewInterval is the time to wait between lock renewals.
LockRenewInterval = 3 * time.Second
// LockRetryInterval is the amount of time to wait if the lock fails before trying again.
LockRetryInterval = 5 * time.Second
// LockWatchRetryInterval is the amount of time to wait if a watch fails before trying again.
// It is also the tick period of the watch loop itself.
LockWatchRetryInterval = 2 * time.Second
// LockTTL is the default lock TTL. A cached lock record older than this is treated as stale.
LockTTL = 15 * time.Second
// LockWatchRetryMax bounds the number of consecutive failed watch reads before signaling that
// leadership is lost. Note that watchLock gives up once the failure count reaches LockWatchRetryMax-1.
LockWatchRetryMax = 4
// LockCacheMinAcceptableAge is the minimum time the cached lock record must have gone unchanged
// before an instance that does not hold the lock considers it safe to try acquiring it.
LockCacheMinAcceptableAge = 45 * time.Second
// LockWriteRetriesOnFailures is the upper bound on PUT attempts that writeLock makes when writes
// fail with a 5xx status.
LockWriteRetriesOnFailures = 4
// ObjectStorageCallsReadTimeout is the timeout applied to each object storage read made by get.
ObjectStorageCallsReadTimeout = 3 * time.Second
// ObjectStorageCallsWriteTimeout is the timeout applied to each object storage PUT attempt made by writeLock.
ObjectStorageCallsWriteTimeout = 3 * time.Second
)
// LockCache is the in-memory snapshot of the lock object that a Lock keeps in
// its lockRecordCache field. writeLock replaces it whenever it observes a new
// ETag in object storage or successfully writes the lock itself.
type LockCache struct {
// ETag values are unique identifiers generated by the OCI service and changed every time the object is modified.
// An empty etag means no lock object was found when the snapshot was taken.
etag string
// lastUpdate is the time (UTC) at which this snapshot was stored.
lastUpdate time.Time
// lockRecord is the record read from, or written to, object storage; nil when no lock object exists.
lockRecord *LockRecord
}
// Lock is the physical.Lock implementation backed by a single object in the
// backend's lock bucket. Instances are created by Backend.LockWith.
type Lock struct {
// backend is the underlying physical backend.
backend *Backend
// Key is the name of the Key. Value is the Value of the Key.
key, value string
// held is a boolean indicating if the lock is currently held.
held bool
// Identity is the internal Identity of this Key (unique to this server instance).
identity string
// internalLock serializes Lock() and Unlock(); each holds it for its full duration.
internalLock sync.Mutex
// stopCh is the channel that stops all operations. It may be closed in the
// event of a leader loss or graceful shutdown. stopped is a boolean
// indicating if we are stopped - it exists to prevent double closing the
// channel. stopLock is a mutex around the locks.
stopCh chan struct{}
stopped bool
stopLock sync.Mutex
// lockRecordCache holds the latest *LockCache snapshot (nil until the first writeLock call stores one).
lockRecordCache atomic.Value
// Allow modifying the Lock durations for ease of unit testing.
renewInterval time.Duration
retryInterval time.Duration
ttl time.Duration
watchRetryInterval time.Duration
watchRetryMax int
}
// LockRecord is the document stored, JSON-encoded, as the body of the lock object.
type LockRecord struct {
// Key is the lock key; it is also used as the object name.
Key string
// Value is the value supplied to LockWith by the instance that wrote the record.
Value string
// Identity is the per-Lock UUID of the instance that wrote the record.
Identity string
}
// Metric keys emitted by the lock implementation.
var (
// metricLockUnlock times Lock.Unlock.
metricLockUnlock = []string{"oci", "lock", "unlock"}
// metricLockLock times Lock.Lock.
metricLockLock = []string{"oci", "lock", "lock"}
// metricLockValue times Lock.Value.
metricLockValue = []string{"oci", "lock", "Value"}
// metricLeaderValue is a gauge set to 1 after writeLock successfully writes the lock record.
metricLeaderValue = []string{"oci", "leader", "Value"}
)
// HAEnabled reports whether high availability is enabled for this backend,
// as recorded in the backend's haEnabled field.
func (b *Backend) HAEnabled() bool {
return b.haEnabled
}
// LockWith builds a Lock for the given key and value without contacting
// object storage. Each call generates a fresh UUID that identifies the
// returned lock, and the lock starts out in the stopped state with the
// package-default intervals, TTL and watch retry limit.
//
// An error is returned only when the UUID cannot be generated.
func (b *Backend) LockWith(key, value string) (physical.Lock, error) {
	id, genErr := uuid.GenerateUUID()
	if genErr != nil {
		return nil, fmt.Errorf("Lock with: %w", genErr)
	}

	lock := new(Lock)
	lock.backend = b
	lock.key = key
	lock.value = value
	lock.identity = id

	// No stop channel exists yet, so mark the lock as already stopped.
	lock.stopped = true

	// Defaults; tests may overwrite these on the returned value.
	lock.renewInterval = LockRenewInterval
	lock.retryInterval = LockRetryInterval
	lock.ttl = LockTTL
	lock.watchRetryInterval = LockWatchRetryInterval
	lock.watchRetryMax = LockWatchRetryMax

	return lock, nil
}
// Lock blocks until the lock is acquired, stopCh is closed, or an acquisition
// attempt fails.
//
// On success it returns a channel that is closed when leadership is lost or
// the lock is released. It returns (nil, nil) when stopCh is closed before the
// lock could be acquired, and an error when the lock is already held by this
// instance or an attempt fails.
func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	l.backend.logger.Debug("Lock() called")
	defer metrics.MeasureSince(metricLockLock, time.Now().UTC())

	l.internalLock.Lock()
	defer l.internalLock.Unlock()

	if l.held {
		return nil, errors.New("lock already held")
	}

	// attemptLock only returns once the lock is ours, the caller gave up,
	// or an attempt failed.
	got, lockErr := l.attemptLock(stopCh)
	switch {
	case lockErr != nil:
		return nil, fmt.Errorf("lock: %w", lockErr)
	case !got:
		return nil, nil
	}

	// From here on this instance is the lock holder.
	l.held = true

	// Publish a fresh leadership channel before the background goroutines
	// start reading it.
	leaderCh := make(chan struct{})
	l.stopLock.Lock()
	l.stopCh = leaderCh
	l.stopped = false
	l.stopLock.Unlock()

	// Keep the lock renewed and keep checking that it is still ours.
	go l.renewLock()
	go l.watchLock()

	return leaderCh, nil
}
// attemptLock tries to write the lock record once per retryInterval tick (the
// first attempt happens after the first tick, not immediately). It returns
// (true, nil) as soon as a write succeeds, (false, nil) when stopCh is closed
// first, and (false, err) when an attempt reports an error.
func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) {
	l.backend.logger.Debug("AttemptLock() called")

	retryTicker := time.NewTicker(l.retryInterval)
	defer retryTicker.Stop()

	for {
		select {
		case <-stopCh:
			// Caller no longer wants the lock.
			return false, nil

		case <-retryTicker.C:
			ok, err := l.writeLock()
			if err != nil {
				return false, fmt.Errorf("attempt lock: %w", err)
			}
			if ok {
				return true, nil
			}
			// Not acquired this round; wait for the next tick.
		}
	}
}
// renewLock rewrites the lock record once per renewInterval tick until the
// lock's stop channel is closed. It is started as a goroutine by Lock after
// the lock has been acquired.
//
// Renewal is best effort: a failed write does not end the loop (watchLock is
// responsible for giving up leadership), but the failure is logged so that a
// lock that silently stops being renewed can be diagnosed.
func (l *Lock) renewLock() {
	l.backend.logger.Debug("RenewLock() called")

	ticker := time.NewTicker(l.renewInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// The boolean result is irrelevant here; only failures matter.
			if _, err := l.writeLock(); err != nil {
				l.backend.logger.Warn("RenewLock: failed to renew lock", "err", err)
			}
		case <-l.stopCh:
			return
		}
	}
}
// loadLockRecordCache returns the lock's most recently stored cache snapshot,
// or nil when nothing has been stored yet.
func loadLockRecordCache(l *Lock) *LockCache {
	if cached := l.lockRecordCache.Load(); cached != nil {
		return cached.(*LockCache)
	}
	return nil
}
// watchLock checks whether the lock object in object storage still belongs to
// this instance and closes the leader channel (stopCh) when it does not. It is
// started as a goroutine by Lock after the lock has been acquired.
//
// On every tick of watchRetryInterval it verifies that the cached lock record
// is present, carries this lock's identity and is not older than ttl, and then
// reads the lock object and verifies its identity as well. If a read fails,
// watchLock retries on the next tick and gives up once the number of
// consecutive failures reaches watchRetryMax-1.
//
// Whatever ends the loop, stopCh is closed on the way out unless it has
// already been closed (tracked by stopped, guarded by stopLock).
func (l *Lock) watchLock() {
l.backend.logger.Debug("WatchLock() called")
// retries counts consecutive failed reads; it is reset after each successful check.
retries := 0
ticker := time.NewTicker(l.watchRetryInterval)
defer ticker.Stop()
OUTER:
for {
// Check if the channel is already closed
select {
case <-l.stopCh:
l.backend.logger.Debug("WatchLock():Stop lock signaled/closed.")
break OUTER
default:
}
// Check if we've exceeded retries
if retries >= l.watchRetryMax-1 {
l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. Giving up the lease after max retries")
break OUTER
}
// Wait for the timer
select {
case <-ticker.C:
case <-l.stopCh:
break OUTER
}
// First check the local cache: leadership is only kept while the cached
// record exists, names this identity and was refreshed within ttl.
lockRecordCache := loadLockRecordCache(l)
if (lockRecordCache == nil) ||
(lockRecordCache.lockRecord == nil) ||
(lockRecordCache.lockRecord.Identity != l.identity) ||
(time.Now().Sub(lockRecordCache.lastUpdate) > l.ttl) {
l.backend.logger.Debug("WatchLock: Lock record cache is nil, stale or does not belong to self.")
break OUTER
}
// Then confirm against object storage. A read error is retriable and
// only counts against the retry budget.
lockRecord, _, err := l.get(context.Background())
if err != nil {
retries++
l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. Retrying..")
metrics.SetGauge(metricHaWatchLockRetriable, 1)
continue
}
// The stored record (not the cache) is missing or names another identity:
// leadership is lost.
if (lockRecord == nil) || (lockRecord.Identity != l.identity) {
l.backend.logger.Debug("WatchLock: Lock record cache is nil or does not belong to self.")
break OUTER
}
// reset retries counter on success
retries = 0
l.backend.logger.Debug("WatchLock() successful")
metrics.SetGauge(metricHaWatchLockRetriable, 0)
}
// Signal loss of leadership exactly once; Unlock may have closed the channel already.
l.stopLock.Lock()
defer l.stopLock.Unlock()
if !l.stopped {
l.stopped = true
l.backend.logger.Debug("Closing the stop channel to give up leadership.")
close(l.stopCh)
}
}
// Unlock releases the lock. It is a no-op when the lock is not held.
//
// It first closes the stop channel (if still open) so that the renew and
// watch goroutines exit, marks the lock as no longer held, and then deletes
// the lock object from object storage, but only when the stored record still
// carries this lock's identity. The delete is conditional on the ETag that was
// just read, so a record rewritten by another instance in the meantime is not
// removed.
//
// An error is returned when the lock record cannot be read, a request ID
// cannot be generated, or the delete fails. The lock is considered released
// locally even when an error is returned.
func (l *Lock) Unlock() error {
	l.backend.logger.Debug("Unlock() called")
	defer metrics.MeasureSince(metricLockUnlock, time.Now().UTC())

	l.internalLock.Lock()
	defer l.internalLock.Unlock()

	if !l.held {
		return nil
	}

	// Stop any existing locking or renewal attempts
	l.stopLock.Lock()
	if !l.stopped {
		l.stopped = true
		close(l.stopCh)
	}
	l.stopLock.Unlock()

	// We are no longer holding the lock
	l.held = false

	// Get current lock record
	currentLockRecord, etag, err := l.get(context.Background())
	if err != nil {
		return fmt.Errorf("error reading lock record: %w", err)
	}

	// Nothing to delete if the object is gone or now belongs to someone else.
	if currentLockRecord == nil || currentLockRecord.Identity != l.identity {
		return nil
	}

	defer metrics.MeasureSince(metricDeleteHa, time.Now())

	opcClientRequestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Debug("Unlock: error generating UUID")
		return fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("Unlock", "opc-client-request-id", opcClientRequestID)

	request := objectstorage.DeleteObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		IfMatch:            &etag,
		OpcClientRequestId: &opcClientRequestID,
	}

	// Bound the delete like every other object storage call in this file, so
	// an unresponsive service cannot block Unlock (and internalLock) forever.
	ctx, cancel := context.WithTimeout(context.Background(), ObjectStorageCallsWriteTimeout)
	defer cancel()

	response, err := l.backend.client.DeleteObject(ctx, request)
	l.backend.logRequest("deleteHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
	if err != nil {
		metrics.IncrCounter(metricDeleteFailed, 1)
		return fmt.Errorf("delete lock: %w", err)
	}

	return nil
}
// Value reads the lock object and reports who currently holds the lock.
//
// It returns (true, value, nil) when a lock record exists, (false, "", nil)
// when there is no lock object, and (false, "", err) when the read fails.
func (l *Lock) Value() (bool, string, error) {
	l.backend.logger.Debug("Value() called")
	defer metrics.MeasureSince(metricLockValue, time.Now().UTC())

	record, _, err := l.get(context.Background())
	switch {
	case err != nil:
		return false, "", err
	case record == nil:
		// No lock object: nobody holds the lock.
		return false, "", nil
	default:
		return true, record.Value, nil
	}
}
// get fetches the lock object from object storage and decodes it.
//
// The read is bounded by ObjectStorageCallsReadTimeout on top of ctx. It
// returns the decoded record together with the object's ETag. A 404 response
// is not an error: it yields (nil, "", nil), meaning no lock object exists.
// Any other request failure, a body read failure or a JSON decode failure is
// counted in metricGetFailed and returned as an error.
func (l *Lock) get(ctx context.Context) (*LockRecord, string, error) {
	l.backend.logger.Debug("Called getLockRecord()")

	// Time the whole read, including decoding.
	defer metrics.MeasureSince(metricGetHa, time.Now())

	requestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Error("getHa: error generating UUID")
		return nil, "", fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("getHa", "opc-client-request-id", requestID)

	readCtx, cancel := context.WithTimeout(ctx, ObjectStorageCallsReadTimeout)
	defer cancel()

	response, err := l.backend.client.GetObject(readCtx, objectstorage.GetObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		OpcClientRequestId: &requestID,
	})
	l.backend.logRequest("getHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
	if err != nil {
		// A missing object simply means there is no lock right now.
		notFound := response.RawResponse != nil && response.RawResponse.StatusCode == http.StatusNotFound
		if notFound {
			return nil, "", nil
		}
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error calling GET", "err", err)
		return nil, "", fmt.Errorf("failed to read Value for %q: %w", l.key, err)
	}
	defer response.RawResponse.Body.Close()

	payload, err := ioutil.ReadAll(response.Content)
	if err != nil {
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error reading content", "err", err)
		return nil, "", fmt.Errorf("failed to decode Value into bytes: %w", err)
	}

	record := new(LockRecord)
	if err := json.Unmarshal(payload, record); err != nil {
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error un-marshalling content", "err", err)
		return nil, "", fmt.Errorf("failed to read Value for %q: %w", l.key, err)
	}

	return record, *response.ETag, nil
}
// writeLock tries to write this instance's lock record to object storage,
// either to renew a lock it already holds or to take over the lock.
//
// If the cached record is missing, belongs to another identity, or is older
// than ttl, the instance acts as a secondary: it re-reads the lock object,
// refreshes the cache when the ETag changed, and backs off (false, nil) while
// a record exists whose ETag has been unchanged for less than
// LockCacheMinAcceptableAge. Otherwise it goes straight to the write.
//
// The write is a conditional PUT: If-None-Match "*" when no lock object is
// known (empty cached ETag), If-Match on the cached ETag otherwise, so two
// instances cannot both succeed. On success the cache is replaced with the
// new record and ETag and (true, nil) is returned. Read, encode, UUID and
// write failures are returned as errors with acquired == false.
func (l *Lock) writeLock() (bool, error) {
	l.backend.logger.Debug("WriteLock() called")

	// Create a transaction to read and the update (maybe)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The transaction will be retried, and it could sit in a queue behind, say,
	// the delete operation. To stop the transaction, we close the context when
	// the associated stopCh is received.
	go func() {
		select {
		case <-l.stopCh:
			cancel()
		case <-ctx.Done():
		}
	}()

	lockRecordCache := loadLockRecordCache(l)
	if lockRecordCache == nil || lockRecordCache.lockRecord == nil ||
		lockRecordCache.lockRecord.Identity != l.identity ||
		time.Since(lockRecordCache.lastUpdate) > l.ttl {
		// case secondary
		currentLockRecord, currentEtag, err := l.get(ctx)
		if err != nil {
			return false, fmt.Errorf("error reading lock record: %w", err)
		}

		if lockRecordCache == nil || lockRecordCache.etag != currentEtag {
			// update cached lock record
			l.lockRecordCache.Store(&LockCache{
				etag:       currentEtag,
				lastUpdate: time.Now().UTC(),
				lockRecord: currentLockRecord,
			})
			lockRecordCache = loadLockRecordCache(l)
		}

		// Current lock record being null implies that there is no leader. In this case we want to try acquiring lock.
		if currentLockRecord != nil && time.Since(lockRecordCache.lastUpdate) < LockCacheMinAcceptableAge {
			return false, nil
		}
		// cache is old enough and current, try acquiring lock as secondary
	}

	newLockRecord := &LockRecord{
		Key:      l.key,
		Value:    l.value,
		Identity: l.identity,
	}
	newLockRecordJSON, err := json.Marshal(newLockRecord)
	if err != nil {
		return false, fmt.Errorf("error marshalling lock record: %w", err)
	}

	defer metrics.MeasureSince(metricPutHa, time.Now())

	opcClientRequestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Error("putHa: error generating UUID")
		return false, fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("putHa", "opc-client-request-id", opcClientRequestID)

	size := int64(len(newLockRecordJSON))
	// PutObjectBody is deliberately left unset here: a reader can only be
	// consumed once, so each attempt gets a fresh one.
	putRequest := objectstorage.PutObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		ContentLength:      &size,
		OpcMeta:            nil,
		OpcClientRequestId: &opcClientRequestID,
	}
	if lockRecordCache.etag == "" {
		// No lock object known: only create, never overwrite.
		noneMatch := "*"
		putRequest.IfNoneMatch = &noneMatch
	} else {
		// Only overwrite the exact version we last saw.
		putRequest.IfMatch = &lockRecordCache.etag
	}

	newEtag, err := l.putLockRecordWithRetry(ctx, putRequest, newLockRecordJSON)
	if err != nil {
		return false, fmt.Errorf("write lock: %w", err)
	}

	l.backend.logger.Debug("Lock written", "record", string(newLockRecordJSON))
	l.lockRecordCache.Store(&LockCache{
		etag:       newEtag,
		lastUpdate: time.Now().UTC(),
		lockRecord: newLockRecord,
	})
	metrics.SetGauge(metricLeaderValue, 1)
	return true, nil
}

// putLockRecordWithRetry sends the PUT up to LockWriteRetriesOnFailures times,
// retrying only on 5xx responses with a linearly growing pause between
// attempts. It returns the new ETag of the first successful attempt, or the
// error of the last failed one. A success after earlier failures is a success.
func (l *Lock) putLockRecordWithRetry(ctx context.Context, request objectstorage.PutObjectRequest, payload []byte) (string, error) {
	var lastErr error
	for attempt := 1; attempt <= LockWriteRetriesOnFailures; attempt++ {
		etag, retryable, err := l.putLockRecordOnce(ctx, request, payload)
		if err == nil {
			return etag, nil
		}
		lastErr = err
		if !retryable {
			break
		}
		// No point in pausing when there is no further attempt.
		if attempt < LockWriteRetriesOnFailures {
			time.Sleep(time.Duration(100*attempt) * time.Millisecond)
		}
	}
	return "", lastErr
}

// putLockRecordOnce performs a single PUT bounded by
// ObjectStorageCallsWriteTimeout and reports whether a failure is worth
// retrying (5xx response). Keeping one attempt per call lets the timeout
// context be released as soon as the attempt finishes.
func (l *Lock) putLockRecordOnce(ctx context.Context, request objectstorage.PutObjectRequest, payload []byte) (etag string, retryable bool, err error) {
	writeCtx, writeCancel := context.WithTimeout(ctx, ObjectStorageCallsWriteTimeout)
	defer writeCancel()

	// Fresh reader per attempt so a retry sends the full body again.
	request.PutObjectBody = ioutil.NopCloser(bytes.NewReader(payload))

	response, err := l.backend.client.PutObject(writeCtx, request)
	l.backend.logRequest("putHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
	if err == nil {
		etag = *response.ETag
		response.RawResponse.Body.Close()
		return etag, false, nil
	}

	// No HTTP response at all (e.g. timeout or cancellation): do not retry.
	if response.RawResponse == nil {
		metrics.IncrCounter(metricPutFailed, 1)
		l.backend.logger.Error("PUT", "err", err)
		return "", false, err
	}
	response.RawResponse.Body.Close()

	// Retry if the return code is 5xx
	if response.RawResponse.StatusCode/100 == 5 {
		metrics.IncrCounter(metricPutFailed, 1)
		l.backend.logger.Warn("PUT. Retrying..", "err", err)
		return "", true, err
	}

	// Anything else (e.g. a failed precondition) is final.
	l.backend.logger.Error("PUT", "err", err)
	return "", false, err
}