open-vault/vault/eventbus/bus_test.go

421 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package eventbus
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
)
// TestBusBasics exercises the basic lifecycle of the event bus:
//   - sending before Start must fail with ErrNotStarted;
//   - sending after Start must succeed even with no subscribers;
//   - a subscriber for an event type must receive an event of that type.
func TestBusBasics(t *testing.T) {
	eventBus, err := NewEventBus(nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	evType := logical.EventType("someType")

	early, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}

	// The bus has not been started yet, so the send must be rejected.
	if sendErr := eventBus.SendInternal(ctx, namespace.RootNamespace, nil, evType, early); sendErr != ErrNotStarted {
		t.Errorf("Expected not started error but got: %v", sendErr)
	}

	eventBus.Start()

	// After Start, sending succeeds even though nobody has subscribed yet.
	if sendErr := eventBus.SendInternal(ctx, namespace.RootNamespace, nil, evType, early); sendErr != nil {
		t.Errorf("Expected no error sending: %v", sendErr)
	}

	sub, unsubscribe, err := eventBus.Subscribe(ctx, namespace.RootNamespace, string(evType))
	if err != nil {
		t.Fatal(err)
	}
	defer unsubscribe()

	want, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	if sendErr := eventBus.SendInternal(ctx, namespace.RootNamespace, nil, evType, want); sendErr != nil {
		t.Error(sendErr)
	}

	deadline := time.After(1 * time.Second)
	select {
	case <-deadline:
		t.Error("Timeout waiting for message")
	case msg := <-sub:
		if got := msg.Payload.(*logical.EventReceived).Event.Id; got != want.Id {
			t.Errorf("Got unexpected message: %+v", msg)
		}
	}
}
// TestBusIgnoresSendContext verifies that SendInternal does not give up on
// delivery just because the caller's context is already canceled: a
// subscriber must still receive the event.
func TestBusIgnoresSendContext(t *testing.T) {
	eventBus, err := NewEventBus(nil)
	if err != nil {
		t.Fatal(err)
	}
	evType := logical.EventType("someType")
	want, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	eventBus.Start()

	sub, unsubscribe, err := eventBus.Subscribe(context.Background(), namespace.RootNamespace, string(evType))
	if err != nil {
		t.Fatal(err)
	}
	defer unsubscribe()

	// Hand SendInternal a context that is canceled before the call is made.
	canceledCtx, cancelSend := context.WithCancel(context.Background())
	cancelSend()
	if sendErr := eventBus.SendInternal(canceledCtx, namespace.RootNamespace, nil, evType, want); sendErr != nil {
		t.Errorf("Expected no error sending: %v", sendErr)
	}

	select {
	case msg := <-sub:
		if msg.Payload.(*logical.EventReceived).Event.Id != want.Id {
			t.Errorf("Got unexpected message: %+v", msg)
		}
	case <-time.After(1 * time.Second):
		t.Error("Timeout waiting for message")
	}
}
// TestNamespaceFiltering verifies that events for other namespaces are filtered
// out by the bus: an event sent to the "abc" namespace must not reach a
// root-namespace subscriber, while an event sent to the root namespace must.
func TestNamespaceFiltering(t *testing.T) {
	bus, err := NewEventBus(nil)
	if err != nil {
		t.Fatal(err)
	}
	bus.Start()
	ctx := context.Background()
	eventType := logical.EventType("someType")
	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
	if err != nil {
		t.Fatal(err)
	}
	defer cancel()

	// Use a distinct event for each namespace so that a late-arriving "abc"
	// event cannot be mistaken for the root event checked below.
	otherNsEvent, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	event, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}

	err = bus.SendInternal(ctx, &namespace.Namespace{
		ID:   "abc",
		Path: "/abc",
	}, nil, eventType, otherNsEvent)
	if err != nil {
		t.Error(err)
	}
	timeout := time.After(100 * time.Millisecond)
	select {
	case <-ch:
		t.Errorf("Got abc namespace message when root namespace was specified")
	case <-timeout:
		// okay: nothing delivered for the other namespace
	}

	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
	if err != nil {
		t.Error(err)
	}
	timeout = time.After(1 * time.Second)
	select {
	case message := <-ch:
		if message.Payload.(*logical.EventReceived).Event.Id != event.Id {
			t.Errorf("Got unexpected message %+v but was waiting for %+v", message, event)
		}
	case <-timeout:
		t.Error("Timed out waiting for message")
	}
}
// TestBus2Subscriptions verifies that events of different types are successfully
// routed to the correct subscribers: a subscriber to someType1 receives the
// someType1 event and a subscriber to someType2 receives the someType2 event.
func TestBus2Subscriptions(t *testing.T) {
	bus, err := NewEventBus(nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	eventType1 := logical.EventType("someType1")
	eventType2 := logical.EventType("someType2")
	bus.Start()
	ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1))
	if err != nil {
		t.Fatal(err)
	}
	defer cancel1()
	ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2))
	if err != nil {
		t.Fatal(err)
	}
	defer cancel2()
	event1, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	event2, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	// Events are sent in the reverse order of subscription.
	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType2, event2)
	if err != nil {
		t.Error(err)
	}
	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType1, event1)
	if err != nil {
		t.Error(err)
	}

	// Both receives share one 1s deadline. A context deadline is used instead of
	// a single time.After channel: time.After delivers only one value, so if the
	// first select consumed it the second select would block forever. The
	// context's Done channel stays closed once the deadline has passed.
	timeoutCtx, cancelTimeout := context.WithTimeout(ctx, 1*time.Second)
	defer cancelTimeout()
	select {
	case message := <-ch1:
		if message.Payload.(*logical.EventReceived).Event.Id != event1.Id {
			t.Errorf("Got unexpected message: %v", message)
		}
	case <-timeoutCtx.Done():
		t.Error("Timeout waiting for event1")
	}
	select {
	case message := <-ch2:
		if message.Payload.(*logical.EventReceived).Event.Id != event2.Id {
			t.Errorf("Got unexpected message: %v", message)
		}
	case <-timeoutCtx.Done():
		t.Error("Timeout waiting for event2")
	}
}
// TestBusSubscriptionsCancel verifies that canceled subscriptions are cleaned up.
//
// It creates `create` subscriptions and delivers one event to all of them.
// The readers with index >= `stop` then stop receiving:
//   - with cancel=true they also call their cancel function;
//   - with cancel=false they only stop reading, and the bus's shortened send
//     timeout is expected to drop them.
//
// A second event is then sent. In both cases the package-level subscriptions
// counter must settle at create-stop.
func TestBusSubscriptionsCancel(t *testing.T) {
	testCases := []struct {
		cancel bool
	}{
		{cancel: true},
		{cancel: false},
	}
	for _, tc := range testCases {
		tc := tc // per-iteration copy: the reader goroutines below read tc.cancel
		t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) {
			subscriptions.Store(0)
			bus, err := NewEventBus(nil)
			if err != nil {
				t.Fatal(err)
			}
			ctx := context.Background()
			if !tc.cancel {
				// set the timeout very short to make the test faster if we aren't canceling explicitly
				bus.SetSendTimeout(100 * time.Millisecond)
			}
			bus.Start()

			// done is closed when the subtest finishes, so reader goroutines
			// that never stop on their own do not outlive the test.
			done := make(chan struct{})
			t.Cleanup(func() { close(done) })

			// create and stop a bunch of subscriptions
			const create = 100
			const stop = 50
			eventType := logical.EventType("someType")
			var stopped atomic.Int32
			var received atomic.Int32
			for i := 0; i < create; i++ {
				ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
				if err != nil {
					t.Fatal(err)
				}
				t.Cleanup(cancelFunc)
				go func(i int32, ch <-chan *eventlogger.Event, cancelFunc context.CancelFunc) {
					// recv waits for one message and counts it. It returns false,
					// without counting, if the channel was closed or the subtest is over.
					recv := func() bool {
						select {
						case _, ok := <-ch:
							if !ok {
								return false
							}
							received.Add(1)
							return true
						case <-done:
							return false
						}
					}
					if !recv() { // always receive one message
						return
					}
					// readers with an index below stop keep receiving until the subtest ends
					for i < int32(stop) {
						if !recv() {
							return
						}
					}
					if tc.cancel {
						cancelFunc() // stop explicitly to unsubscribe
					}
					stopped.Add(1)
				}(int32(i), ch, cancelFunc)
			}

			// check that all channels receive a message
			event, err := logical.NewEvent()
			if err != nil {
				t.Fatal(err)
			}
			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
			if err != nil {
				t.Error(err)
			}
			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create) })
			waitFor(t, 1*time.Second, func() bool { return stopped.Load() == int32(stop) })

			// send another message, but half should stop receiving
			event, err = logical.NewEvent()
			if err != nil {
				t.Fatal(err)
			}
			err = bus.SendInternal(ctx, namespace.RootNamespace, nil, eventType, event)
			if err != nil {
				t.Error(err)
			}
			waitFor(t, 1*time.Second, func() bool { return received.Load() == int32(create*2-stop) })
			// the sends should time out and the subscriptions should drop when cancelFunc is called or the context cancels
			waitFor(t, 1*time.Second, func() bool { return subscriptions.Load() == int64(create-stop) })
		})
	}
}
// waitFor waits for a condition to be true, up to the maximum timeout.
//
// f is always called at least once. Between calls waitFor sleeps with an
// exponential backoff that starts at 1ms and doubles after every attempt. Each
// sleep is capped at the time remaining before the deadline, so the wait does
// not overshoot maxWait by more than one sleep-and-call cycle. If f has not
// returned true by the deadline, the test is marked failed with t.Error and
// execution continues.
func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
	t.Helper()
	deadline := time.Now().Add(maxWait)
	if f() {
		return
	}
	for backoff := time.Millisecond; ; backoff *= 2 {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			break
		}
		sleep := backoff
		if sleep > remaining {
			// never sleep past the deadline
			sleep = remaining
		}
		time.Sleep(sleep)
		if f() {
			return
		}
	}
	t.Error("Timeout waiting for condition")
}
// TestBusWildcardSubscriptions tests that a single subscription can receive
// multiple event types using * for glob patterns. A "kv/*" subscriber is
// expected to receive both kv/foo and kv/bar, and a "*/bar" subscriber is
// expected to receive kv/bar.
func TestBusWildcardSubscriptions(t *testing.T) {
	bus, err := NewEventBus(nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	fooEventType := logical.EventType("kv/foo")
	barEventType := logical.EventType("kv/bar")
	bus.Start()
	ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*")
	if err != nil {
		t.Fatal(err)
	}
	defer cancel1()
	ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar")
	if err != nil {
		t.Fatal(err)
	}
	defer cancel2()
	event1, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	event2, err := logical.NewEvent()
	if err != nil {
		t.Fatal(err)
	}
	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2)
	if err != nil {
		t.Error(err)
	}
	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1)
	if err != nil {
		t.Error(err)
	}

	// All receives below share one 1s deadline. A context deadline is used
	// instead of a single time.After channel: time.After delivers only one
	// value, so after one timeout every later select would block forever. The
	// context's Done channel stays closed once the deadline has passed.
	timeoutCtx, cancelTimeout := context.WithTimeout(ctx, 1*time.Second)
	defer cancelTimeout()

	// Expect to receive both events on ch1, which subscribed to kv/*
	var ch1Seen []string
collect:
	for i := 0; i < 2; i++ {
		select {
		case message := <-ch1:
			ch1Seen = append(ch1Seen, message.Payload.(*logical.EventReceived).Event.Id)
		case <-timeoutCtx.Done():
			t.Error("Timeout waiting for event1")
			break collect // the deadline has passed; further waits would time out immediately
		}
	}
	if len(ch1Seen) != 2 {
		t.Errorf("Expected 2 events but got: %v", ch1Seen)
	} else {
		if !strutil.StrListContains(ch1Seen, event1.Id) {
			t.Errorf("Did not find %s event1 ID in ch1seen", event1.Id)
		}
		if !strutil.StrListContains(ch1Seen, event2.Id) {
			t.Errorf("Did not find %s event2 ID in ch1seen", event2.Id)
		}
	}

	// Expect to receive just kv/bar on ch2, which subscribed to */bar
	select {
	case message := <-ch2:
		if message.Payload.(*logical.EventReceived).Event.Id != event2.Id {
			t.Errorf("Got unexpected message: %v", message)
		}
	case <-timeoutCtx.Done():
		t.Error("Timeout waiting for event2")
	}
}