package postgresql

import (
	"fmt"
	"os"
	"testing"
	"time"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/helper/testhelpers/docker"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/physical"
	"github.com/ory/dockertest"

	_ "github.com/lib/pq"
)

// TestPostgreSQLBackend runs the generic physical backend test suites against
// a PostgreSQL backend.
//
// The connection URL, table name and HA flag are read from the PGURL, PGTABLE
// and PGHAENABLED environment variables. When PGURL is empty a Postgres docker
// container is started via prepareTestContainer and removed when the test
// returns. PGTABLE defaults to "vault_kv_store" and PGHAENABLED to "true".
func TestPostgreSQLBackend(t *testing.T) {
	logger := logging.NewVaultLogger(log.Debug)

	// Use docker as pg backend if no url is provided via environment variables
	var cleanup func()
	connURL := os.Getenv("PGURL")
	if connURL == "" {
		cleanup, connURL = prepareTestContainer(t, logger)
		defer cleanup()
	}

	table := os.Getenv("PGTABLE")
	if table == "" {
		table = "vault_kv_store"
	}

	hae := os.Getenv("PGHAENABLED")
	if hae == "" {
		hae = "true"
	}

	// Run vault tests
	logger.Info(fmt.Sprintf("Connection URL: %v", connURL))

	// Two independent backends are created so the HA tests below can contend
	// for the same lock from two different clients.
	b1, err := NewPostgreSQLBackend(map[string]string{
		"connection_url": connURL,
		"table":          table,
		"ha_enabled":     hae,
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}

	b2, err := NewPostgreSQLBackend(map[string]string{
		"connection_url": connURL,
		"table":          table,
		"ha_enabled":     hae,
	}, logger)
	if err != nil {
		t.Fatalf("Failed to create new backend: %v", err)
	}

	pg := b1.(*PostgreSQLBackend)

	// Read postgres version to test that a basic connection works
	var pgversion string
	if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
		t.Fatalf("Failed to check for Postgres version: %v", err)
	}
	logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))

	setupDatabaseObjects(t, logger, pg)

	// Empty the key/value table when the test finishes. This defer is
	// registered after the container cleanup above, so it runs before the
	// container (if any) is torn down.
	defer func() {
		_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
		if err != nil {
			t.Fatalf("Failed to truncate table: %v", err)
		}
	}()

	logger.Info("Running basic backend tests")
	physical.ExerciseBackend(t, b1)
	logger.Info("Running list prefix backend tests")
	physical.ExerciseBackend_ListPrefix(t, b1)

	ha1, ok := b1.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	ha2, ok := b2.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	// The HA suites only run when both backends report HA as enabled.
	if ha1.HAEnabled() && ha2.HAEnabled() {
		logger.Info("Running ha backend tests")
		physical.ExerciseHABackend(t, ha1, ha2)
		testPostgresSQLLockTTL(t, ha1)
		testPostgresSQLLockRenewal(t, ha1)
	}
}

// TestPostgreSQLBackendMaxIdleConnectionsParameter verifies that a
// non-numeric max_idle_connections value makes NewPostgreSQLBackend return the
// expected parse error.
func TestPostgreSQLBackendMaxIdleConnectionsParameter(t *testing.T) {
	_, err := NewPostgreSQLBackend(map[string]string{
		"connection_url":       "some connection url",
		"max_idle_connections": "bad param",
	}, logging.NewVaultLogger(log.Debug))
	if err == nil {
		// Fatal rather than Error: the comparison below calls err.Error(),
		// which would panic on a nil error.
		t.Fatal("Expected invalid max_idle_connections param to return error")
	}

	expectedErrStr := "failed parsing max_idle_connections parameter: strconv.Atoi: parsing \"bad param\": invalid syntax"
	if err.Error() != expectedErrStr {
		t.Errorf("Expected: \"%s\" but found \"%s\"", expectedErrStr, err.Error())
	}
}

// Similar to testHABackend, but using internal implementation details to
// trigger the lock failure scenario by setting the lock renew period for one
// of the locks to a higher value than the lock TTL.
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
	// Set much smaller lock times to speed up the test.
	lockTTL := 3
	renewInterval := time.Second * 1
	retryInterval := time.Second * 1
	longRenewInterval := time.Duration(lockTTL*2) * time.Second
	lockkey := "postgresttl"

	// Declared here because it is checked after the first lock's scope ends.
	var leaderCh <-chan struct{}

	// Get the lock
	origLock, err := ha.LockWith(lockkey, "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	{
		// set the first lock renew period to double the expected TTL.
		lock := origLock.(*PostgreSQLLock)
		lock.renewInterval = longRenewInterval
		lock.ttlSeconds = lockTTL

		// Attempt to lock
		leaderCh, err = lock.Lock(nil)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if leaderCh == nil {
			t.Fatalf("failed to get leader ch")
		}

		// Check the value
		held, val, err := lock.Value()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if !held {
			t.Fatalf("should be held")
		}
		if val != "bar" {
			t.Fatalf("bad value: %v", val)
		}
	}

	// Second acquisition should succeed because the first lock should
	// not renew within the 3 sec TTL.
	origLock2, err := ha.LockWith(lockkey, "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	{
		lock2 := origLock2.(*PostgreSQLLock)
		lock2.renewInterval = renewInterval
		lock2.ttlSeconds = lockTTL
		lock2.retryInterval = retryInterval

		// Cancel attempt in 6 sec so as not to block unit tests forever
		stopCh := make(chan struct{})
		time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
			close(stopCh)
		})

		// Attempt to lock should work
		leaderCh2, err := lock2.Lock(stopCh)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if leaderCh2 == nil {
			t.Fatalf("should get leader ch")
		}
		defer lock2.Unlock()

		// Check the value
		held, val, err := lock2.Value()
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if !held {
			t.Fatalf("should be held")
		}
		if val != "baz" {
			t.Fatalf("bad value: %v", val)
		}
	}

	// The first lock should have lost the leader channel
	select {
	case <-time.After(longRenewInterval * 2):
		t.Fatalf("original lock did not have its leader channel closed.")
	case <-leaderCh:
	}
}

// Verify that once Unlock is called, we don't keep trying to renew the original
// lock.
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) { // Get the lock origLock, err := ha.LockWith("pgrenewal", "bar") if err != nil { t.Fatalf("err: %v", err) } // customize the renewal and watch intervals lock := origLock.(*PostgreSQLLock) // lock.renewInterval = time.Second * 1 // Attempt to lock leaderCh, err := lock.Lock(nil) if err != nil { t.Fatalf("err: %v", err) } if leaderCh == nil { t.Fatalf("failed to get leader ch") } // Check the value held, val, err := lock.Value() if err != nil { t.Fatalf("err: %v", err) } if !held { t.Fatalf("should be held") } if val != "bar" { t.Fatalf("bad value: %v", val) } // Release the lock, which will delete the stored item if err := lock.Unlock(); err != nil { t.Fatalf("err: %v", err) } // Wait longer than the renewal time time.Sleep(1500 * time.Millisecond) // Attempt to lock with new lock newLock, err := ha.LockWith("pgrenewal", "baz") if err != nil { t.Fatalf("err: %v", err) } // Cancel attempt after lock ttl + 1s so as not to block unit tests forever stopCh := make(chan struct{}) timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second time.AfterFunc(timeout, func() { t.Logf("giving up on lock attempt after %v", timeout) close(stopCh) }) // Attempt to lock should work leaderCh2, err := newLock.Lock(stopCh) if err != nil { t.Fatalf("err: %v", err) } if leaderCh2 == nil { t.Fatalf("should get leader ch") } // Check the value held, val, err = newLock.Value() if err != nil { t.Fatalf("err: %v", err) } if !held { t.Fatalf("should be held") } if val != "baz" { t.Fatalf("bad value: %v", val) } // Cleanup newLock.Unlock() } func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) { // If environment variable is set, use this connectionstring without starting docker container if os.Getenv("PGURL") != "" { return func() {}, os.Getenv("PGURL") } pool, err := dockertest.NewPool("") if err != nil { t.Fatalf("Failed to connect to docker: %s", 
err) } // using 11.1 which is currently latest, use hard version for stability of tests resource, err := pool.Run("postgres", "11.1", []string{}) if err != nil { t.Fatalf("Could not start docker Postgres: %s", err) } retConnString = fmt.Sprintf("postgres://postgres@localhost:%v/postgres?sslmode=disable", resource.GetPort("5432/tcp")) cleanup = func() { docker.CleanupResource(t, pool, resource) } // Provide a test function to the pool to test if docker instance service is up. // We try to setup a pg backend as test for successful connect // exponential backoff-retry, because the dockerinstance may not be able to accept // connections yet, test by trying to setup a postgres backend, max-timeout is 60s if err := pool.Retry(func() error { var err error _, err = NewPostgreSQLBackend(map[string]string{ "connection_url": retConnString, }, logger) return err }); err != nil { cleanup() t.Fatalf("Could not connect to docker: %s", err) } return cleanup, retConnString } func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) { var err error // Setup tables and indexes if not exists. 
createTableSQL := fmt.Sprintf( " CREATE TABLE IF NOT EXISTS %v ( "+ " parent_path TEXT COLLATE \"C\" NOT NULL, "+ " path TEXT COLLATE \"C\", "+ " key TEXT COLLATE \"C\", "+ " value BYTEA, "+ " CONSTRAINT pkey PRIMARY KEY (path, key) "+ " ); ", pg.table) _, err = pg.client.Exec(createTableSQL) if err != nil { t.Fatalf("Failed to create table: %v", err) } createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table) _, err = pg.client.Exec(createIndexSQL) if err != nil { t.Fatalf("Failed to create index: %v", err) } createHaTableSQL := " CREATE TABLE IF NOT EXISTS vault_ha_locks ( " + " ha_key TEXT COLLATE \"C\" NOT NULL, " + " ha_identity TEXT COLLATE \"C\" NOT NULL, " + " ha_value TEXT COLLATE \"C\", " + " valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " + " CONSTRAINT ha_key PRIMARY KEY (ha_key) " + " ); " _, err = pg.client.Exec(createHaTableSQL) if err != nil { t.Fatalf("Failed to create hatable: %v", err) } }