2cf5009ff5
Make lock2's retryInterval smaller so it grabs the lock as soon as lock1's renewer fails to renew in time. Fix the logic that tests whether lock1's leader channel gets closed: we don't need a goroutine, and the old logic was broken in that, if we timed out, we'd never write to the blocking channel we then try to read from. Moreover, the timeout was wrong.
package postgresql

import (
	"fmt"
	"os"
	"testing"
	"time"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/testhelpers/docker"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/physical"
	"github.com/ory/dockertest"

	_ "github.com/lib/pq"
)

func TestPostgreSQLBackend(t *testing.T) {
	logger := logging.NewVaultLogger(log.Debug)

	// Use docker as pg backend if no url is provided via environment variables
	var cleanup func()
	connURL := os.Getenv("PGURL")
	if connURL == "" {
		cleanup, connURL = prepareTestContainer(t, logger)
		defer cleanup()
	}

	table := os.Getenv("PGTABLE")
	if table == "" {
		table = "vault_kv_store"
	}

	hae := os.Getenv("PGHAENABLED")
	if hae == "" {
		hae = "true"
	}

	// Run vault tests
	logger.Info(fmt.Sprintf("Connection URL: %v", connURL))

	b1, err := NewPostgreSQLBackend(map[string]string{
		"connection_url": connURL,
		"table":          table,
		"ha_enabled":     hae,
	}, logger)

	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}

	b2, err := NewPostgreSQLBackend(map[string]string{
		"connection_url": connURL,
		"table":          table,
		"ha_enabled":     hae,
	}, logger)

	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}
	pg := b1.(*PostgreSQLBackend)

	// Read the Postgres version to check that basic connectivity works
	var pgversion string
	if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
		t.Fatalf("Failed to check for Postgres version: %v", err)
	}
	logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))

	setupDatabaseObjects(t, logger, pg)

	defer func() {
		pg := b1.(*PostgreSQLBackend)
		_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
		if err != nil {
			t.Fatalf("Failed to truncate table: %v", err)
		}
	}()

	logger.Info("Running basic backend tests")
	physical.ExerciseBackend(t, b1)
	logger.Info("Running list prefix backend tests")
	physical.ExerciseBackend_ListPrefix(t, b1)

	ha1, ok := b1.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	ha2, ok := b2.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	if ha1.HAEnabled() && ha2.HAEnabled() {
		logger.Info("Running ha backend tests")
		physical.ExerciseHABackend(t, ha1, ha2)
		testPostgresSQLLockTTL(t, ha1)
		testPostgresSQLLockRenewal(t, ha1)
	}
}

// Similar to testHABackend, but using internal implementation details to
// trigger the lock failure scenario by setting the lock renew period for one
// of the locks to a higher value than the lock TTL.
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
	// Set much smaller lock times to speed up the test.
	lockTTL := 3
	renewInterval := time.Second * 1
	retryInterval := time.Second * 1
	longRenewInterval := time.Duration(lockTTL*2) * time.Second
	lockkey := "postgresttl"

	var leaderCh <-chan struct{}

	// Get the lock
	origLock, err := ha.LockWith(lockkey, "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	{
		// set the first lock renew period to double the expected TTL.
		lock := origLock.(*PostgreSQLLock)
		lock.renewInterval = longRenewInterval
		lock.ttlSeconds = lockTTL

		// Attempt to lock
		leaderCh, err = lock.Lock(nil)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if leaderCh == nil {
			t.Fatalf("failed to get leader ch")
		}

		// Check the value
		held, val, err := lock.Value()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if !held {
			t.Fatalf("should be held")
		}
		if val != "bar" {
			t.Fatalf("bad value: %v", val)
		}
	}

	// Second acquisition should succeed because the first lock should
	// not renew within the 3 sec TTL.
	origLock2, err := ha.LockWith(lockkey, "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	{
		lock2 := origLock2.(*PostgreSQLLock)
		lock2.renewInterval = renewInterval
		lock2.ttlSeconds = lockTTL
		lock2.retryInterval = retryInterval

		// Cancel attempt in 6 sec so as not to block unit tests forever
		stopCh := make(chan struct{})
		time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
			close(stopCh)
		})

		// Attempt to lock should work
		leaderCh2, err := lock2.Lock(stopCh)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if leaderCh2 == nil {
			t.Fatalf("should get leader ch")
		}
		defer lock2.Unlock()

		// Check the value
		held, val, err := lock2.Value()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if !held {
			t.Fatalf("should be held")
		}
		if val != "baz" {
			t.Fatalf("bad value: %v", val)
		}
	}

	// The first lock should have lost the leader channel
	select {
	case <-time.After(longRenewInterval * 2):
		t.Fatalf("original lock did not have its leader channel closed.")
	case <-leaderCh:
	}
}

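// The wait at the end of testPostgresSQLLockTTL above is the pattern the commit
// message describes: select directly on the leader channel with a timeout,
// instead of spawning a goroutine that writes to an intermediate channel (which
// would never be written to if the timeout fired first). Below is a minimal
// sketch of that pattern as a reusable helper; waitForLeaderLoss is illustrative
// only and is not used elsewhere in this file.
func waitForLeaderLoss(t *testing.T, leaderCh <-chan struct{}, timeout time.Duration) {
	t.Helper()
	select {
	case <-time.After(timeout):
		t.Fatalf("leader channel was not closed within %v", timeout)
	case <-leaderCh:
		// Leadership was lost as expected: the renewer closed the channel.
	}
}
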
// Verify that once Unlock is called, we don't keep trying to renew the original
// lock.
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
	// Get the lock
	origLock, err := ha.LockWith("pgrenewal", "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// customize the renewal and watch intervals
	lock := origLock.(*PostgreSQLLock)
	// lock.renewInterval = time.Second * 1

	// Attempt to lock
	leaderCh, err := lock.Lock(nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh == nil {
		t.Fatalf("failed to get leader ch")
	}

	// Check the value
	held, val, err := lock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "bar" {
		t.Fatalf("bad value: %v", val)
	}

	// Release the lock, which will delete the stored item
	if err := lock.Unlock(); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Wait longer than the renewal time
	time.Sleep(1500 * time.Millisecond)

	// Attempt to lock with new lock
	newLock, err := ha.LockWith("pgrenewal", "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Cancel the attempt after lock TTL + retry interval + 1s so as not to block unit tests forever
	stopCh := make(chan struct{})
	timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
	time.AfterFunc(timeout, func() {
		t.Logf("giving up on lock attempt after %v", timeout)
		close(stopCh)
	})

	// Attempt to lock should work
	leaderCh2, err := newLock.Lock(stopCh)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh2 == nil {
		t.Fatalf("should get leader ch")
	}

	// Check the value
	held, val, err = newLock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "baz" {
		t.Fatalf("bad value: %v", val)
	}

	// Cleanup
	newLock.Unlock()
}

func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) {
	// If the environment variable is set, use that connection string without starting a docker container
	if os.Getenv("PGURL") != "" {
		return func() {}, os.Getenv("PGURL")
	}

	pool, err := dockertest.NewPool("")
	if err != nil {
		t.Fatalf("Failed to connect to docker: %s", err)
	}
	// Use Postgres 11.1, which is currently the latest; pin the version for test stability
	resource, err := pool.Run("postgres", "11.1", []string{})
	if err != nil {
		t.Fatalf("Could not start docker Postgres: %s", err)
	}

	retConnString = fmt.Sprintf("postgres://postgres@localhost:%v/postgres?sslmode=disable", resource.GetPort("5432/tcp"))

	cleanup = func() {
		docker.CleanupResource(t, pool, resource)
	}

	// Provide a test function to the pool to check whether the docker instance is up.
	// We test for a successful connection by trying to set up a postgres backend.
	// The pool retries with exponential backoff because the docker instance may not
	// be able to accept connections yet; the maximum timeout is 60s.
	if err := pool.Retry(func() error {
		var err error
		_, err = NewPostgreSQLBackend(map[string]string{
			"connection_url": retConnString,
		}, logger)
		return err
	}); err != nil {
		cleanup()
		t.Fatalf("Could not connect to docker: %s", err)
	}

	return cleanup, retConnString
}

func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
	var err error
	// Set up tables and indexes if they do not exist.
	createTableSQL := fmt.Sprintf(
		" CREATE TABLE IF NOT EXISTS %v ( "+
			" parent_path TEXT COLLATE \"C\" NOT NULL, "+
			" path TEXT COLLATE \"C\", "+
			" key TEXT COLLATE \"C\", "+
			" value BYTEA, "+
			" CONSTRAINT pkey PRIMARY KEY (path, key) "+
			" ); ", pg.table)

	_, err = pg.client.Exec(createTableSQL)
	if err != nil {
		t.Fatalf("Failed to create table: %v", err)
	}

	createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)

	_, err = pg.client.Exec(createIndexSQL)
	if err != nil {
		t.Fatalf("Failed to create index: %v", err)
	}

	createHaTableSQL :=
		" CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
			" ha_key TEXT COLLATE \"C\" NOT NULL, " +
			" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
			" ha_value TEXT COLLATE \"C\", " +
			" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
			" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
			" ); "

	_, err = pg.client.Exec(createHaTableSQL)
	if err != nil {
		t.Fatalf("Failed to create hatable: %v", err)
	}
}