Make lock2's retryInterval smaller so that it grabs the lock as soon as lock1's renewer fails to renew in time. Fix the logic that tests whether lock1's leader channel gets closed: no goroutine is needed, and the old logic was broken because, on timeout, nothing was ever written to the blocking channel we then tried to read from. The timeout value was also wrong.
358 lines
9.1 KiB
package postgresql
import (
log "github.com/hashicorp/go-hclog"
_ "github.com/lib/pq"
// TestPostgreSQLBackend runs the generic physical backend, list-prefix and HA
// test suites against PostgreSQL.
//
// The server is taken from PGURL; when PGURL is empty a docker container is
// started for the duration of the test. PGTABLE (default "vault_kv_store") and
// PGHAENABLED (default "true") configure the two backend instances under test.
func TestPostgreSQLBackend(t *testing.T) {
	logger := logging.NewVaultLogger(log.Debug)

	// envOr returns the named environment variable, or def when it is empty.
	envOr := func(key, def string) string {
		if v := os.Getenv(key); v != "" {
			return v
		}
		return def
	}

	// Use docker as pg backend if no url is provided via environment variables.
	connURL := os.Getenv("PGURL")
	if connURL == "" {
		stopContainer, url := prepareTestContainer(t, logger)
		defer stopContainer()
		connURL = url
	}
	table := envOr("PGTABLE", "vault_kv_store")
	haEnabled := envOr("PGHAENABLED", "true")

	logger.Info(fmt.Sprintf("Connection URL: %v", connURL))

	// backendConfig builds a fresh configuration map for each backend instance.
	backendConfig := func() map[string]string {
		return map[string]string{
			"connection_url": connURL,
			"table":          table,
			"ha_enabled":     haEnabled,
		}
	}

	b1, err := NewPostgreSQLBackend(backendConfig(), logger)
	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}
	b2, err := NewPostgreSQLBackend(backendConfig(), logger)
	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}

	pg := b1.(*PostgreSQLBackend)

	// Reading the server version doubles as a basic connectivity check.
	var pgVersion string
	if err := pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgVersion); err != nil {
		t.Fatalf("Failed to check for Postgres version: %v", err)
	}
	logger.Info(fmt.Sprintf("Postgres Version: %v", pgVersion))

	setupDatabaseObjects(t, logger, pg)

	// Empty the KV table when the test ends. Defers run LIFO, so this happens
	// before any container cleanup registered above.
	defer func() {
		if _, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table)); err != nil {
			t.Fatalf("Failed to truncate table: %v", err)
		}
	}()

	logger.Info("Running basic backend tests")
	physical.ExerciseBackend(t, b1)
	logger.Info("Running list prefix backend tests")
	physical.ExerciseBackend_ListPrefix(t, b1)

	ha1, ok := b1.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}
	ha2, ok := b2.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	// The HA suites only apply when both backends report HA as enabled.
	if !ha1.HAEnabled() || !ha2.HAEnabled() {
		return
	}
	logger.Info("Running ha backend tests")
	physical.ExerciseHABackend(t, ha1, ha2)
	testPostgresSQLLockTTL(t, ha1)
	testPostgresSQLLockRenewal(t, ha1)
}
// Similar to testHABackend, but using internal implementation details to
// trigger the lock failure scenario by setting the lock renew period for one
// of the locks to a higher value than the lock TTL.
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
// Set much smaller lock times to speed up the test.
lockTTL := 3
renewInterval := time.Second * 1
retryInterval := time.Second * 1
longRenewInterval := time.Duration(lockTTL*2) * time.Second
lockkey := "postgresttl"
var leaderCh <-chan struct{}
// Get the lock
origLock, err := ha.LockWith(lockkey, "bar")
if err != nil {
t.Fatalf("err: %v", err)
// set the first lock renew period to double the expected TTL.
lock := origLock.(*PostgreSQLLock)
lock.renewInterval = longRenewInterval
lock.ttlSeconds = lockTTL
// Attempt to lock
leaderCh, err = lock.Lock(nil)
if err != nil {
t.Fatalf("err: %v", err)
if leaderCh == nil {
t.Fatalf("failed to get leader ch")
// Check the value
held, val, err := lock.Value()
if err != nil {
t.Fatalf("err: %v", err)
if !held {
t.Fatalf("should be held")
if val != "bar" {
t.Fatalf("bad value: %v", val)
// Second acquisition should succeed because the first lock should
// not renew within the 3 sec TTL.
origLock2, err := ha.LockWith(lockkey, "baz")
if err != nil {
t.Fatalf("err: %v", err)
lock2 := origLock2.(*PostgreSQLLock)
lock2.renewInterval = renewInterval
lock2.ttlSeconds = lockTTL
lock2.retryInterval = retryInterval
// Cancel attempt in 6 sec so as not to block unit tests forever
stopCh := make(chan struct{})
time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
// Attempt to lock should work
leaderCh2, err := lock2.Lock(stopCh)
if err != nil {
t.Fatalf("err: %v", err)
if leaderCh2 == nil {
t.Fatalf("should get leader ch")
defer lock2.Unlock()
// Check the value
held, val, err := lock2.Value()
if err != nil {
t.Fatalf("err: %v", err)
if !held {
t.Fatalf("should be held")
if val != "baz" {
t.Fatalf("bad value: %v", val)
// The first lock should have lost the leader channel
select {
case <-time.After(longRenewInterval * 2):
t.Fatalf("original lock did not have its leader channel closed.")
case <-leaderCh:
// testPostgresSQLLockRenewal verifies that once Unlock is called, we don't keep
// trying to renew the original lock. A fresh lock on the same key must become
// acquirable before the cancellation timeout elapses.
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
	// Get the lock
	origLock, err := ha.LockWith("pgrenewal", "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Customize the renewal interval so that the sleep after Unlock below
	// spans at least one renewal period.
	lock := origLock.(*PostgreSQLLock)
	lock.renewInterval = time.Second * 1

	// Attempt to lock
	leaderCh, err := lock.Lock(nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh == nil {
		t.Fatalf("failed to get leader ch")
	}

	// Check the value
	held, val, err := lock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "bar" {
		t.Fatalf("bad value: %v", val)
	}

	// Release the lock, which will delete the stored item
	if err := lock.Unlock(); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Wait longer than the renewal time
	time.Sleep(lock.renewInterval + 500*time.Millisecond)

	// Attempt to lock with new lock
	newLock, err := ha.LockWith("pgrenewal", "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Cancel the attempt after lock TTL + retry interval + 1s so as not to block
	// unit tests forever. The callback must close stopCh to have any effect.
	// The timer is stopped on return because its callback logs through t,
	// and logging after the test has completed makes the testing package panic.
	stopCh := make(chan struct{})
	timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
	cancelTimer := time.AfterFunc(timeout, func() {
		t.Logf("giving up on lock attempt after %v", timeout)
		close(stopCh)
	})
	defer cancelTimer.Stop()

	// Attempt to lock should work
	leaderCh2, err := newLock.Lock(stopCh)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh2 == nil {
		t.Fatalf("should get leader ch")
	}
	// Cleanup: release the new lock even if a later check fails, so the key
	// is not left held.
	defer newLock.Unlock()

	// Check the value
	held, val, err = newLock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "baz" {
		t.Fatalf("bad value: %v", val)
	}
}
func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) {
// If environment variable is set, use this connectionstring without starting docker container
if os.Getenv("PGURL") != "" {
return func() {}, os.Getenv("PGURL")
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
// using 11.1 which is currently latest, use hard version for stability of tests
resource, err := pool.Run("postgres", "11.1", []string{})
if err != nil {
t.Fatalf("Could not start docker Postgres: %s", err)
retConnString = fmt.Sprintf("postgres://postgres@localhost:%v/postgres?sslmode=disable", resource.GetPort("5432/tcp"))
cleanup = func() {
docker.CleanupResource(t, pool, resource)
// Provide a test function to the pool to test if docker instance service is up.
// We try to setup a pg backend as test for successful connect
// exponential backoff-retry, because the dockerinstance may not be able to accept
// connections yet, test by trying to setup a postgres backend, max-timeout is 60s
if err := pool.Retry(func() error {
var err error
_, err = NewPostgreSQLBackend(map[string]string{
"connection_url": retConnString,
}, logger)
return err
}); err != nil {
t.Fatalf("Could not connect to docker: %s", err)
return cleanup, retConnString
func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
var err error
// Setup tables and indexes if not exists.
createTableSQL := fmt.Sprintf(
" parent_path TEXT COLLATE \"C\" NOT NULL, "+
" path TEXT COLLATE \"C\", "+
" key TEXT COLLATE \"C\", "+
" value BYTEA, "+
" CONSTRAINT pkey PRIMARY KEY (path, key) "+
" ); ", pg.table)
_, err = pg.client.Exec(createTableSQL)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
_, err = pg.client.Exec(createIndexSQL)
if err != nil {
t.Fatalf("Failed to create index: %v", err)
createHaTableSQL :=
" CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
" ha_key TEXT COLLATE \"C\" NOT NULL, " +
" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
" ha_value TEXT COLLATE \"C\", " +
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
" ); "
_, err = pg.client.Exec(createHaTableSQL)
if err != nil {
t.Fatalf("Failed to create hatable: %v", err)