diff --git a/vault/activity_log_util_common.go b/vault/activity_log_util_common.go index a57f046a2..3ae891553 100644 --- a/vault/activity_log_util_common.go +++ b/vault/activity_log_util_common.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "io" "sort" "strings" "time" @@ -15,6 +16,7 @@ import ( "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/activity" + "google.golang.org/protobuf/proto" ) type HLLGetter func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error) @@ -288,3 +290,96 @@ func (a *ActivityLog) mountAccessorToMountPath(mountAccessor string) string { } return displayPath } + +type singleTypeSegmentReader struct { + basePath string + startTime time.Time + paths []string + currentPathIndex int + a *ActivityLog +} +type segmentReader struct { + tokens *singleTypeSegmentReader + entities *singleTypeSegmentReader +} + +// SegmentReader is an interface that provides methods to read tokens and entities in order +type SegmentReader interface { + ReadToken(ctx context.Context) (*activity.TokenCount, error) + ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) +} + +func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.Time) (SegmentReader, error) { + entities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityEntityBasePath) + if err != nil { + return nil, err + } + tokens, err := a.newSingleTypeSegmentReader(ctx, startTime, activityTokenBasePath) + if err != nil { + return nil, err + } + return &segmentReader{entities: entities, tokens: tokens}, nil +} + +func (a *ActivityLog) newSingleTypeSegmentReader(ctx context.Context, startTime time.Time, prefix string) (*singleTypeSegmentReader, error) { + basePath := prefix + fmt.Sprint(startTime.Unix()) + "/" + pathList, err := a.view.List(ctx, basePath) + if err != nil { + return nil, err + } + return &singleTypeSegmentReader{ + basePath: basePath, + startTime: startTime, + paths: pathList, + currentPathIndex: 0, + a: a, + }, nil +} + +func (s *singleTypeSegmentReader) nextValue(ctx context.Context, out proto.Message) error { + var raw *logical.StorageEntry + var path string + for raw == nil { + if s.currentPathIndex >= len(s.paths) { + return io.EOF + } + path = s.paths[s.currentPathIndex] + // increment the index to continue iterating for the next read call, even if an error occurs during this call + s.currentPathIndex++ + var err error + raw, err = s.a.view.Get(ctx, s.basePath+path) + if err != nil { + return err + } + if raw == nil { + s.a.logger.Warn("expected log segment file has been deleted", "startTime", s.startTime, "segmentPath", path) + } + } + err := proto.Unmarshal(raw.Value, out) + if err != nil { + return fmt.Errorf("unable to parse segment file %v%v: %w", s.basePath, path, err) + } + return nil +} + +// ReadToken reads a token from the segment +// If there is none available, then the error will be io.EOF +func (e *segmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, error) { + out := &activity.TokenCount{} + err := e.tokens.nextValue(ctx, out) + if err != nil { + return nil, err + } + return out, nil +} + +// ReadEntity reads an entity from the segment +// If there is none available, then the error will be io.EOF +func (e *segmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) { + out := &activity.EntityActivityLog{} + err := e.entities.nextValue(ctx, out) + if err != nil { + return nil, err + } + return out, nil +} diff --git a/vault/activity_log_util_common_test.go b/vault/activity_log_util_common_test.go index e40e3d6f6..817dbf398 100644 --- a/vault/activity_log_util_common_test.go +++ b/vault/activity_log_util_common_test.go @@ -5,12 +5,17 @@ package vault import ( "context" + "errors" "fmt" + "io" "testing" "time" "github.com/axiomhq/hyperloglog" "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/vault/activity" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" ) // Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal creates 3 months of hyperloglogs and fills them with @@ -158,3 +163,181 @@ func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T) t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients) } } + +// writeEntitySegment writes a single segment file with the given time and index for an entity +func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) { + t.Helper() + protoItem, err := proto.Marshal(item) + require.NoError(t, err) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem) +} + +// writeTokenSegment writes a single segment file with the given time and index for a token +func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) { + t.Helper() + protoItem, err := proto.Marshal(item) + require.NoError(t, err) + WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, ts, index), protoItem) +} + +// makeSegmentPath formats the path for a segment at a particular time and index +func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string { + t.Helper() + return fmt.Sprintf("%s%s%d/%d", ActivityPrefix, typ, ts.Unix(), index) +} + +// TestSegmentFileReader_BadData verifies that the reader returns errors when the data is unable to be parsed +// However, the next time that Read*() is called, the reader should still progress and be able to then return any +// valid data without errors +func TestSegmentFileReader_BadData(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + now := time.Now() + + // write bad data that won't be able to be unmarshaled at index 0 + WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data")) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data")) + + // write entity at index 1 + entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{ + { + ClientID: "id", + }, + }} + writeEntitySegment(t, core, now, 1, entity) + + // write token at index 1 + token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{ + "ns": 1, + }} + writeTokenSegment(t, core, now, 1, token) + reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) + require.NoError(t, err) + + // first the bad entity is read, which returns an error + _, err = reader.ReadEntity(context.Background()) + require.Error(t, err) + // then, the reader can read the good entity at index 1 + gotEntity, err := reader.ReadEntity(context.Background()) + require.True(t, proto.Equal(gotEntity, entity)) + require.Nil(t, err) + + // the bad token causes an error + _, err = reader.ReadToken(context.Background()) + require.Error(t, err) + // but the good token is able to be read + gotToken, err := reader.ReadToken(context.Background()) + require.True(t, proto.Equal(gotToken, token)) + require.Nil(t, err) +} + +// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths without +// errorring until it is able to find a valid segment path +func TestSegmentFileReader_MissingData(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + now := time.Now() + // write entities and tokens at indexes 0, 1, 2 + for i := 0; i < 3; i++ { + WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data")) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data")) + + } + // write entity at index 3 + entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{ + { + ClientID: "id", + }, + }} + writeEntitySegment(t, core, now, 3, entity) + // write token at index 3 + token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{ + "ns": 1, + }} + writeTokenSegment(t, core, now, 3, token) + reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) + require.NoError(t, err) + + // delete the indexes 0, 1, 2 + for i := 0; i < 3; i++ { + require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenBasePath, now, i))) + require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i))) + } + + // we expect the reader to only return the data at index 3, and then be done + gotEntity, err := reader.ReadEntity(context.Background()) + require.NoError(t, err) + require.True(t, proto.Equal(gotEntity, entity)) + _, err = reader.ReadEntity(context.Background()) + require.Equal(t, err, io.EOF) + + gotToken, err := reader.ReadToken(context.Background()) + require.NoError(t, err) + require.True(t, proto.Equal(gotToken, token)) + _, err = reader.ReadToken(context.Background()) + require.Equal(t, err, io.EOF) +} + +// TestSegmentFileReader_NoData verifies that the reader return io.EOF when there is no data +func TestSegmentFileReader_NoData(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + now := time.Now() + reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) + require.NoError(t, err) + entity, err := reader.ReadEntity(context.Background()) + require.Nil(t, entity) + require.Equal(t, err, io.EOF) + token, err := reader.ReadToken(context.Background()) + require.Nil(t, token) + require.Equal(t, err, io.EOF) +} + +// TestSegmentFileReader verifies that the reader iterates through all segments paths in ascending order and returns +// io.EOF when it's done +func TestSegmentFileReader(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + now := time.Now() + entities := make([]*activity.EntityActivityLog, 0, 3) + tokens := make([]*activity.TokenCount, 0, 3) + + // write 3 entity segment pieces and 3 token segment pieces + for i := 0; i < 3; i++ { + entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{ + { + ClientID: fmt.Sprintf("id-%d", i), + }, + }} + token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{ + fmt.Sprintf("ns-%d", i): uint64(i), + }} + writeEntitySegment(t, core, now, i, entity) + writeTokenSegment(t, core, now, i, token) + entities = append(entities, entity) + tokens = append(tokens, token) + } + + reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) + require.NoError(t, err) + + gotEntities := make([]*activity.EntityActivityLog, 0, 3) + gotTokens := make([]*activity.TokenCount, 0, 3) + + // read the entities from the reader + for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) { + require.NoError(t, err) + gotEntities = append(gotEntities, entity) + } + + // read the tokens from the reader + for token, err := reader.ReadToken(context.Background()); !errors.Is(err, io.EOF); token, err = reader.ReadToken(context.Background()) { + require.NoError(t, err) + gotTokens = append(gotTokens, token) + } + require.Len(t, gotEntities, 3) + require.Len(t, gotTokens, 3) + + // verify that the entities and tokens we got from the reader are correct + // we can't use require.Equals() here because there are protobuf differences in unexported fields + for i := 0; i < 3; i++ { + require.True(t, proto.Equal(gotEntities[i], entities[i])) + require.True(t, proto.Equal(gotTokens[i], tokens[i])) + } +}