open-vault/vault/ui_test.go

192 lines
4.8 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"context"
"testing"
"github.com/hashicorp/vault/sdk/logical"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical/inmem"
)
func TestConfig_Enabled(t *testing.T) {
logger := logging.NewVaultLogger(log.Trace)
phys, err := inmem.NewTransactionalInmem(nil, logger)
if err != nil {
t.Fatal(err)
}
logl := &logical.InmemStorage{}
config := NewUIConfig(true, phys, logl)
if !config.Enabled() {
t.Fatal("ui should be enabled")
}
config = NewUIConfig(false, phys, logl)
if config.Enabled() {
t.Fatal("ui should not be enabled")
}
}
// TestConfig_Headers exercises the custom-header lifecycle of a UIConfig:
// with nothing stored, Headers returns only the defaults; SetHeader stores
// values that GetHeader returns in order (also for a differently-cased name);
// HeaderKeys tracks the stored names; DeleteHeader removes them again.
func TestConfig_Headers(t *testing.T) {
	ctx := context.Background()
	logger := logging.NewVaultLogger(log.Trace)
	phys, err := inmem.NewTransactionalInmem(nil, logger)
	if err != nil {
		t.Fatal(err)
	}
	logl := &logical.InmemStorage{}
	config := NewUIConfig(true, phys, logl)

	// requireHeader fails the test unless GetHeader(name) returns exactly the
	// values in want, in order. A nil/empty want asserts the header is absent.
	requireHeader := func(name string, want []string) {
		t.Helper()
		head, err := config.GetHeader(ctx, name)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if len(head) != len(want) {
			t.Fatalf("header %q: expected %d values, got %#v", name, len(want), head)
		}
		for i := range want {
			if head[i] != want[i] {
				t.Fatalf("header %q value %d: expected: %s, got: %s", name, i, want[i], head[i])
			}
		}
	}

	// requireKeyCount fails the test unless HeaderKeys reports exactly want keys.
	requireKeyCount := func(want int) {
		t.Helper()
		keys, err := config.HeaderKeys(ctx)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		if len(keys) != want {
			t.Fatalf("expected %d keys, got %d", want, len(keys))
		}
	}

	// With no custom headers stored, the returned set is just the defaults.
	headers, err := config.Headers(ctx)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(headers) != len(config.defaultHeaders) {
		t.Fatalf("expected %d headers, got %d", len(config.defaultHeaders), len(headers))
	}

	requireHeader("Test-Header", nil)

	if err := config.SetHeader(ctx, "Test-Header", []string{"123", "456"}); err != nil {
		t.Fatalf("err: %v", err)
	}
	requireHeader("Test-Header", []string{"123", "456"})
	// Header name lookups are expected to be case-insensitive.
	requireHeader("tEST-hEADER", []string{"123", "456"})
	requireKeyCount(1)

	if err := config.SetHeader(ctx, "Test-Header-2", []string{"321"}); err != nil {
		t.Fatalf("err: %v", err)
	}
	requireHeader("Test-Header-2", []string{"321"})
	requireKeyCount(2)

	for _, name := range []string{"Test-Header-2", "Test-Header"} {
		if err := config.DeleteHeader(ctx, name); err != nil {
			t.Fatalf("err: %v", err)
		}
	}
	requireHeader("Test-Header", nil)
	requireKeyCount(0)
}
func TestConfig_DefaultHeaders(t *testing.T) {
logger := logging.NewVaultLogger(log.Trace)
phys, err := inmem.NewTransactionalInmem(nil, logger)
if err != nil {
t.Fatal(err)
}
logl := &logical.InmemStorage{}
config := NewUIConfig(true, phys, logl)
headers, err := config.Headers(context.Background())
if err != nil {
t.Fatalf("err: %v", err)
}
if len(headers) != len(config.defaultHeaders) {
t.Fatalf("expected %d headers, got %d", len(config.defaultHeaders), len(headers))
}
headers, err = config.Headers(context.Background())
if err != nil {
t.Fatalf("err: %v", err)
}
defaultCSP := config.defaultHeaders.Get("Content-security-Policy")
head := headers.Get("Content-Security-Policy")
if head != defaultCSP {
t.Fatalf("header does not match: expected %s, got %s", defaultCSP, head)
}
err = config.SetHeader(context.Background(), "Content-security-Policy", []string{"test"})
if err != nil {
t.Fatalf("err: %v", err)
}
headers, err = config.Headers(context.Background())
if err != nil {
t.Fatalf("err: %v", err)
}
head = headers.Get("Content-Security-Policy")
if head != "test" {
t.Fatalf("header does not match: expected %s, got %s", "test", head)
}
err = config.DeleteHeader(context.Background(), "Content-Security-Policy")
if err != nil {
t.Fatalf("err: %v", err)
}
headers, err = config.Headers(context.Background())
if err != nil {
t.Fatalf("err: %v", err)
}
head = headers.Get("Content-Security-Policy")
if err != nil {
t.Fatalf("err: %v", err)
}
if head != defaultCSP {
t.Fatalf("header does not match: expected %s, got %s", defaultCSP, head)
}
}