Support event subscriptions with glob wildcards (#19205)
This commit is contained in:
parent
ccab6ab676
commit
2472029a0a
|
@ -20,13 +20,13 @@ import (
|
|||
)
|
||||
|
||||
type eventSubscribeArgs struct {
|
||||
ctx context.Context
|
||||
logger hclog.Logger
|
||||
events *eventbus.EventBus
|
||||
ns *namespace.Namespace
|
||||
eventType logical.EventType
|
||||
conn *websocket.Conn
|
||||
json bool
|
||||
ctx context.Context
|
||||
logger hclog.Logger
|
||||
events *eventbus.EventBus
|
||||
ns *namespace.Namespace
|
||||
pattern string
|
||||
conn *websocket.Conn
|
||||
json bool
|
||||
}
|
||||
|
||||
// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason
|
||||
|
@ -34,7 +34,7 @@ type eventSubscribeArgs struct {
|
|||
func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) {
|
||||
ctx := args.ctx
|
||||
logger := args.logger
|
||||
ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.eventType)
|
||||
ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.pattern)
|
||||
if err != nil {
|
||||
logger.Info("Error subscribing", "error", err)
|
||||
return websocket.StatusUnsupportedData, "Error subscribing", nil
|
||||
|
@ -97,12 +97,11 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
|
|||
if ns.ID != namespace.RootNamespaceID {
|
||||
prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path)
|
||||
}
|
||||
eventTypeStr := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
|
||||
if eventTypeStr == "" {
|
||||
pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
|
||||
if pattern == "" {
|
||||
respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to"))
|
||||
return
|
||||
}
|
||||
eventType := logical.EventType(eventTypeStr)
|
||||
|
||||
json := false
|
||||
jsonRaw := r.URL.Query().Get("json")
|
||||
|
@ -135,7 +134,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
|
|||
}
|
||||
}()
|
||||
|
||||
closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, eventType, conn, json})
|
||||
closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, pattern, conn, json})
|
||||
if err != nil {
|
||||
closeStatus = websocket.CloseStatus(err)
|
||||
if closeStatus == -1 {
|
||||
|
|
|
@ -16,10 +16,17 @@ import (
|
|||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/ryanuber/go-glob"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const defaultTimeout = 60 * time.Second
|
||||
const (
|
||||
// eventTypeAll is purely internal to the event bus. We use it to send all
|
||||
// events down one big firehose, and pipelines define their own filtering
|
||||
// based on what each subscriber is interested in.
|
||||
eventTypeAll = "*"
|
||||
defaultTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotStarted = errors.New("event broker has not been started")
|
||||
|
@ -45,16 +52,14 @@ type pluginEventBus struct {
|
|||
|
||||
type asyncChanNode struct {
|
||||
// TODO: add bounded deque buffer of *EventReceived
|
||||
ctx context.Context
|
||||
ch chan *logical.EventReceived
|
||||
namespace *namespace.Namespace
|
||||
logger hclog.Logger
|
||||
ctx context.Context
|
||||
ch chan *logical.EventReceived
|
||||
logger hclog.Logger
|
||||
|
||||
// used to close the connection
|
||||
closeOnce sync.Once
|
||||
cancelFunc context.CancelFunc
|
||||
pipelineID eventlogger.PipelineID
|
||||
eventType eventlogger.EventType
|
||||
broker *eventlogger.Broker
|
||||
}
|
||||
|
||||
|
@ -97,7 +102,7 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
|
|||
// We can't easily know when the Send is complete, so we can't call the cancel function.
|
||||
// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
|
||||
ctx, _ = context.WithTimeout(ctx, bus.timeout)
|
||||
_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived)
|
||||
_, err := bus.broker.Send(ctx, eventTypeAll, eventReceived)
|
||||
if err != nil {
|
||||
// if no listeners for this event type are registered, that's okay, the event
|
||||
// will just not be sent anywhere
|
||||
|
@ -164,32 +169,42 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eventType logical.EventType) (<-chan *logical.EventReceived, context.CancelFunc, error) {
|
||||
func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *logical.EventReceived, context.CancelFunc, error) {
|
||||
// subscriptions are still stored even if the bus has not been started
|
||||
pipelineID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nodeID, err := uuid.GenerateUUID()
|
||||
filterNodeID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
filterNode := newFilterNode(ns, pattern)
|
||||
err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sinkNodeID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// TODO: should we have just one node per namespace, and handle all the routing ourselves?
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
asyncNode := newAsyncNode(ctx, ns, bus.logger)
|
||||
err = bus.broker.RegisterNode(eventlogger.NodeID(nodeID), asyncNode)
|
||||
err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode)
|
||||
if err != nil {
|
||||
defer cancel()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nodes := []eventlogger.NodeID{bus.formatterNodeID, eventlogger.NodeID(nodeID)}
|
||||
nodes := []eventlogger.NodeID{eventlogger.NodeID(filterNodeID), bus.formatterNodeID, eventlogger.NodeID(sinkNodeID)}
|
||||
|
||||
pipeline := eventlogger.Pipeline{
|
||||
PipelineID: eventlogger.PipelineID(pipelineID),
|
||||
EventType: eventlogger.EventType(eventType),
|
||||
EventType: eventTypeAll,
|
||||
NodeIDs: nodes,
|
||||
}
|
||||
err = bus.broker.RegisterPipeline(pipeline)
|
||||
|
@ -197,10 +212,10 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
|
|||
defer cancel()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
addSubscriptions(1)
|
||||
// add info needed to cancel the subscription
|
||||
asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
|
||||
asyncNode.eventType = eventlogger.EventType(eventType)
|
||||
asyncNode.cancelFunc = cancel
|
||||
return asyncNode.ch, asyncNode.Close, nil
|
||||
}
|
||||
|
@ -211,12 +226,32 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
|
|||
bus.timeout = timeout
|
||||
}
|
||||
|
||||
func newFilterNode(ns *namespace.Namespace, pattern string) *eventlogger.Filter {
|
||||
return &eventlogger.Filter{
|
||||
Predicate: func(e *eventlogger.Event) (bool, error) {
|
||||
eventRecv := e.Payload.(*logical.EventReceived)
|
||||
|
||||
// Drop if event is not in our namespace.
|
||||
// TODO: add wildcard/child namespace processing here in some cases?
|
||||
if eventRecv.Namespace != ns.Path {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Filter for correct event type, including wildcards.
|
||||
if !glob.Glob(pattern, eventRecv.EventType) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
|
||||
return &asyncChanNode{
|
||||
ctx: ctx,
|
||||
ch: make(chan *logical.EventReceived),
|
||||
namespace: namespace,
|
||||
logger: logger,
|
||||
ctx: ctx,
|
||||
ch: make(chan *logical.EventReceived),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,7 +260,7 @@ func (node *asyncChanNode) Close() {
|
|||
node.closeOnce.Do(func() {
|
||||
defer node.cancelFunc()
|
||||
if node.broker != nil {
|
||||
err := node.broker.RemovePipeline(node.eventType, node.pipelineID)
|
||||
err := node.broker.RemovePipeline(eventTypeAll, node.pipelineID)
|
||||
if err != nil {
|
||||
node.logger.Warn("Error removing pipeline for closing node", "error", err)
|
||||
}
|
||||
|
@ -238,11 +273,6 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
|
|||
// sends to the channel async in another goroutine
|
||||
go func() {
|
||||
eventRecv := e.Payload.(*logical.EventReceived)
|
||||
// drop if event is not in our namespace
|
||||
// TODO: add wildcard processing here in some cases?
|
||||
if eventRecv.Namespace != node.namespace.Path {
|
||||
return
|
||||
}
|
||||
var timeout bool
|
||||
select {
|
||||
case node.ch <- eventRecv:
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
@ -38,7 +39,7 @@ func TestBusBasics(t *testing.T) {
|
|||
t.Errorf("Expected no error sending: %v", err)
|
||||
}
|
||||
|
||||
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
|
||||
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -81,7 +82,7 @@ func TestNamespaceFiltering(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
|
||||
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -137,13 +138,13 @@ func TestBus2Subscriptions(t *testing.T) {
|
|||
eventType2 := logical.EventType("someType2")
|
||||
bus.Start()
|
||||
|
||||
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType1)
|
||||
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cancel1()
|
||||
|
||||
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType2)
|
||||
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
|
|||
received := atomic.Int32{}
|
||||
|
||||
for i := 0; i < create; i++ {
|
||||
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType)
|
||||
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -297,3 +298,78 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
|
|||
}
|
||||
t.Error("Timeout waiting for condition")
|
||||
}
|
||||
|
||||
// TestBusWildcardSubscriptions tests that a single subscription can receive
|
||||
// multiple event types using * for glob patterns.
|
||||
func TestBusWildcardSubscriptions(t *testing.T) {
|
||||
bus, err := NewEventBus(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
fooEventType := logical.EventType("kv/foo")
|
||||
barEventType := logical.EventType("kv/bar")
|
||||
bus.Start()
|
||||
|
||||
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cancel1()
|
||||
|
||||
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer cancel2()
|
||||
|
||||
event1, err := logical.NewEvent()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event2, err := logical.NewEvent()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
timeout := time.After(1 * time.Second)
|
||||
// Expect to receive both events on ch1, which subscribed to kv/*
|
||||
var ch1Seen []string
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case message := <-ch1:
|
||||
ch1Seen = append(ch1Seen, message.Event.ID())
|
||||
case <-timeout:
|
||||
t.Error("Timeout waiting for event1")
|
||||
}
|
||||
}
|
||||
if len(ch1Seen) != 2 {
|
||||
t.Errorf("Expected 2 events but got: %v", ch1Seen)
|
||||
} else {
|
||||
if !strutil.StrListContains(ch1Seen, event1.ID()) {
|
||||
t.Errorf("Did not find %s event1 ID in ch1seen", event1.ID())
|
||||
}
|
||||
if !strutil.StrListContains(ch1Seen, event2.ID()) {
|
||||
t.Errorf("Did not find %s event2 ID in ch1seen", event2.ID())
|
||||
}
|
||||
}
|
||||
// Expect to receive just kv/bar on ch2, which subscribed to */bar
|
||||
select {
|
||||
case message := <-ch2:
|
||||
if message.Event.ID() != event2.ID() {
|
||||
t.Errorf("Got unexpected message: %v", message)
|
||||
}
|
||||
case <-timeout:
|
||||
t.Error("Timeout waiting for event2")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ func TestCanSendEventsFromBuiltinPlugin(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, logical.EventType(eventType))
|
||||
ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, eventType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue