open-vault/sdk/database/dbplugin/v5/middleware_test.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

488 lines
14 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbplugin
import (
"context"
"errors"
"net/url"
"reflect"
"testing"
"github.com/hashicorp/go-hclog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestDatabaseErrorSanitizerMiddleware verifies that the error sanitizer
// middleware rewrites errors using the configured secrets map (each key found
// in the message is replaced by its value). It checks this both when calling
// sanitize directly and when errors come back from each method of the wrapped
// database.
func TestDatabaseErrorSanitizerMiddleware(t *testing.T) {
	type testCase struct {
		inputErr      error
		secretsFunc   func() map[string]string
		expectedError error
	}

	tests := map[string]testCase{
		"nil error": {
			inputErr:      nil,
			expectedError: nil,
		},
		"url error": {
			inputErr:      new(url.Error),
			expectedError: errors.New("unable to parse connection url"),
		},
		"nil secrets func": {
			inputErr:      errors.New("here is my password: iofsd9473tg"),
			expectedError: errors.New("here is my password: iofsd9473tg"),
		},
		"secrets with empty string": {
			inputErr:      errors.New("here is my password: iofsd9473tg"),
			secretsFunc:   secretFunc(t, "", ""),
			expectedError: errors.New("here is my password: iofsd9473tg"),
		},
		"secrets that do not match": {
			inputErr:      errors.New("here is my password: iofsd9473tg"),
			secretsFunc:   secretFunc(t, "asdf", "<redacted>"),
			expectedError: errors.New("here is my password: iofsd9473tg"),
		},
		"secrets that do match": {
			inputErr:      errors.New("here is my password: iofsd9473tg"),
			secretsFunc:   secretFunc(t, "iofsd9473tg", "<redacted>"),
			expectedError: errors.New("here is my password: <redacted>"),
		},
		"multiple secrets": {
			inputErr: errors.New("here is my password: iofsd9473tg"),
			secretsFunc: secretFunc(t,
				"iofsd9473tg", "<redacted>",
				"password", "<this was the word password>",
			),
			expectedError: errors.New("here is my <this was the word password>: <redacted>"),
		},
		// The expected error uses status.Error (not Errorf) because its message
		// is a literal, not a format string.
		"gRPC status error": {
			inputErr:      status.Error(codes.InvalidArgument, "an error with a password iofsd9473tg"),
			secretsFunc:   secretFunc(t, "iofsd9473tg", "<redacted>"),
			expectedError: status.Error(codes.InvalidArgument, "an error with a password <redacted>"),
		},
	}

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			db := fakeDatabase{}
			mw := NewDatabaseErrorSanitizerMiddleware(db, test.secretsFunc)

			actualErr := mw.sanitize(test.inputErr)
			if !reflect.DeepEqual(actualErr, test.expectedError) {
				// %v (not %s) so a nil error prints as "<nil>" rather than "%!s(<nil>)".
				t.Fatalf("Actual error: %v\nExpected error: %v", actualErr, test.expectedError)
			}
		})
	}

	const (
		rawErrMsg      = "password: iofsd9473tg with some stuff after it"
		redactedErrMsg = "password: <redacted> with some stuff after it"
	)

	// callCounts mirrors the per-method counters recorded by recordingDatabase.
	type callCounts struct {
		initializeCalls, newUserCalls, updateUserCalls, deleteUserCalls, typeCalls, closeCalls int
	}

	// assertCalls checks every counter on db against want, in a fixed order.
	assertCalls := func(t *testing.T, db *recordingDatabase, want callCounts) {
		t.Helper()
		assertEquals(t, db.initializeCalls, want.initializeCalls)
		assertEquals(t, db.newUserCalls, want.newUserCalls)
		assertEquals(t, db.updateUserCalls, want.updateUserCalls)
		assertEquals(t, db.deleteUserCalls, want.deleteUserCalls)
		assertEquals(t, db.typeCalls, want.typeCalls)
		assertEquals(t, db.closeCalls, want.closeCalls)
	}

	// Each method must redact the secret from the error produced by the wrapped
	// database and forward exactly one call to it.
	methodTests := []struct {
		name string
		// fake is the wrapped database, primed to fail the method under test.
		fake fakeDatabase
		// call invokes the method under test and returns only its error.
		call func(mw DatabaseErrorSanitizerMiddleware) error
		// want is the set of calls expected to reach the recordingDatabase.
		want callCounts
	}{
		{
			name: "Initialize",
			fake: fakeDatabase{initErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				_, err := mw.Initialize(context.Background(), InitializeRequest{})
				return err
			},
			want: callCounts{initializeCalls: 1},
		},
		{
			name: "NewUser",
			fake: fakeDatabase{newUserErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				_, err := mw.NewUser(context.Background(), NewUserRequest{})
				return err
			},
			want: callCounts{newUserCalls: 1},
		},
		{
			name: "UpdateUser",
			fake: fakeDatabase{updateUserErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
				return err
			},
			want: callCounts{updateUserCalls: 1},
		},
		{
			name: "DeleteUser",
			fake: fakeDatabase{deleteUserErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
				return err
			},
			want: callCounts{deleteUserCalls: 1},
		},
		{
			name: "Type",
			fake: fakeDatabase{typeErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				_, err := mw.Type()
				return err
			},
			want: callCounts{typeCalls: 1},
		},
		{
			name: "Close",
			fake: fakeDatabase{closeErr: errors.New(rawErrMsg)},
			call: func(mw DatabaseErrorSanitizerMiddleware) error {
				return mw.Close()
			},
			want: callCounts{closeCalls: 1},
		},
	}

	for _, tc := range methodTests {
		t.Run(tc.name, func(t *testing.T) {
			db := &recordingDatabase{
				next: tc.fake,
			}
			mw := DatabaseErrorSanitizerMiddleware{
				next:      db,
				secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
			}

			expectedErr := errors.New(redactedErrMsg)
			err := tc.call(mw)
			if !reflect.DeepEqual(err, expectedErr) {
				// %v so a dropped (nil) error is reported legibly.
				t.Fatalf("Actual err: %v\n Expected err: %v", err, expectedErr)
			}

			assertCalls(t, db, tc.want)
		})
	}
}
// secretFunc builds a secrets function from alternating key/value arguments:
// vals[0] maps to vals[1], vals[2] maps to vals[3], and so on. A later
// duplicate key overwrites an earlier one. It fails the test immediately when
// given an odd number of values. Every call of the returned function yields
// the same map instance.
func secretFunc(t *testing.T, vals ...string) func() map[string]string {
	t.Helper()

	if len(vals)%2 == 1 {
		t.Fatalf("Test configuration error: secretFunc must be called with an even number of values")
	}

	secrets := make(map[string]string, len(vals)/2)
	// Consume the argument list two entries at a time.
	for rest := vals; len(rest) > 0; rest = rest[2:] {
		secrets[rest[0]] = rest[1]
	}

	return func() map[string]string { return secrets }
}
func TestTracingMiddleware(t *testing.T) {
t.Run("Initialize", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
_, err := mw.Initialize(context.Background(), InitializeRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 1)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("NewUser", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
_, err := mw.NewUser(context.Background(), NewUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 1)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("UpdateUser", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 1)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("DeleteUser", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 1)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("Type", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
_, err := mw.Type()
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 1)
assertEquals(t, db.closeCalls, 0)
})
t.Run("Close", func(t *testing.T) {
db := &recordingDatabase{}
logger := hclog.NewNullLogger()
mw := databaseTracingMiddleware{
next: db,
logger: logger,
}
err := mw.Close()
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 1)
})
}
func TestMetricsMiddleware(t *testing.T) {
t.Run("Initialize", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
_, err := mw.Initialize(context.Background(), InitializeRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 1)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("NewUser", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
_, err := mw.NewUser(context.Background(), NewUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 1)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("UpdateUser", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 1)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("DeleteUser", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 1)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 0)
})
t.Run("Type", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
_, err := mw.Type()
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 1)
assertEquals(t, db.closeCalls, 0)
})
t.Run("Close", func(t *testing.T) {
db := &recordingDatabase{}
mw := databaseMetricsMiddleware{
next: db,
typeStr: "metrics",
}
err := mw.Close()
if err != nil {
t.Fatalf("Expected no error, but got: %s", err)
}
assertEquals(t, db.initializeCalls, 0)
assertEquals(t, db.newUserCalls, 0)
assertEquals(t, db.updateUserCalls, 0)
assertEquals(t, db.deleteUserCalls, 0)
assertEquals(t, db.typeCalls, 0)
assertEquals(t, db.closeCalls, 1)
})
}
// assertEquals stops the current test with a descriptive failure when the
// observed integer differs from the expected one. It is marked as a helper so
// failures are reported at the caller's line.
func assertEquals(t *testing.T, actual, expected int) {
	t.Helper()
	if actual == expected {
		return
	}
	t.Fatalf("Actual: %d Expected: %d", actual, expected)
}