VAULT-14733: SegmentReader interface for reading activity log segments (#19934)

* create a segment reader for activity log segments

* fix imports

* updates based on comments
Author: miagilepner, 2023-04-06 16:23:41 +02:00 (committed by GitHub)
Parent: 8cd90fc1e2
Commit: 3b91b9ebbf
2 changed files with 278 additions and 0 deletions

@@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"io"
"sort"
"strings"
"time"
@@ -15,6 +16,7 @@ import (
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/activity"
"google.golang.org/protobuf/proto"
)
type HLLGetter func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error)
@@ -288,3 +290,96 @@ func (a *ActivityLog) mountAccessorToMountPath(mountAccessor string) string {
}
return displayPath
}
type singleTypeSegmentReader struct {
basePath string
startTime time.Time
paths []string
currentPathIndex int
a *ActivityLog
}
type segmentReader struct {
tokens *singleTypeSegmentReader
entities *singleTypeSegmentReader
}
// SegmentReader is an interface that provides methods to read tokens and entities in order
type SegmentReader interface {
ReadToken(ctx context.Context) (*activity.TokenCount, error)
ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error)
}
func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.Time) (SegmentReader, error) {
entities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityEntityBasePath)
if err != nil {
return nil, err
}
tokens, err := a.newSingleTypeSegmentReader(ctx, startTime, activityTokenBasePath)
if err != nil {
return nil, err
}
return &segmentReader{entities: entities, tokens: tokens}, nil
}
func (a *ActivityLog) newSingleTypeSegmentReader(ctx context.Context, startTime time.Time, prefix string) (*singleTypeSegmentReader, error) {
basePath := prefix + fmt.Sprint(startTime.Unix()) + "/"
pathList, err := a.view.List(ctx, basePath)
if err != nil {
return nil, err
}
return &singleTypeSegmentReader{
basePath: basePath,
startTime: startTime,
paths: pathList,
currentPathIndex: 0,
a: a,
}, nil
}
func (s *singleTypeSegmentReader) nextValue(ctx context.Context, out proto.Message) error {
var raw *logical.StorageEntry
var path string
for raw == nil {
if s.currentPathIndex >= len(s.paths) {
return io.EOF
}
path = s.paths[s.currentPathIndex]
// increment the index to continue iterating for the next read call, even if an error occurs during this call
s.currentPathIndex++
var err error
raw, err = s.a.view.Get(ctx, s.basePath+path)
if err != nil {
return err
}
if raw == nil {
s.a.logger.Warn("expected log segment file has been deleted", "startTime", s.startTime, "segmentPath", path)
}
}
err := proto.Unmarshal(raw.Value, out)
if err != nil {
return fmt.Errorf("unable to parse segment file %v%v: %w", s.basePath, path, err)
}
return nil
}
// ReadToken reads a token from the segment
// If there is none available, then the error will be io.EOF
func (e *segmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, error) {
out := &activity.TokenCount{}
err := e.tokens.nextValue(ctx, out)
if err != nil {
return nil, err
}
return out, nil
}
// ReadEntity reads an entity from the segment
// If there is none available, then the error will be io.EOF
func (e *segmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) {
out := &activity.EntityActivityLog{}
err := e.entities.nextValue(ctx, out)
if err != nil {
return nil, err
}
return out, nil
}
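
As a usage note, a consumer of SegmentReader would typically construct it via NewSegmentFileReader and drain each stream until io.EOF, skipping individual parse errors because the reader has already advanced past the bad segment file. Below is a minimal sketch, not part of this change: it assumes package vault, and the method name countSegmentClients and its counting logic are illustrative only.

// countSegmentClients is a hypothetical illustration of draining a SegmentReader;
// it counts entity records across every entity segment for the given start time.
func (a *ActivityLog) countSegmentClients(ctx context.Context, startTime time.Time) (int, error) {
	reader, err := a.NewSegmentFileReader(ctx, startTime)
	if err != nil {
		return 0, err
	}
	clients := 0
	for {
		entity, err := reader.ReadEntity(ctx)
		if errors.Is(err, io.EOF) {
			// all entity segments for this start time have been read
			break
		}
		if err != nil {
			// a corrupt segment file; the reader has already advanced, so skip it and keep going
			continue
		}
		clients += len(entity.Clients)
	}
	return clients, nil
}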

@@ -5,12 +5,17 @@ package vault
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
"github.com/axiomhq/hyperloglog"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/vault/activity"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
// Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal creates 3 months of hyperloglogs and fills them with
@@ -158,3 +163,181 @@ func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T)
t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients)
}
}
// writeEntitySegment writes a single segment file with the given time and index for an entity
func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) {
t.Helper()
protoItem, err := proto.Marshal(item)
require.NoError(t, err)
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem)
}
// writeTokenSegment writes a single segment file with the given time and index for a token
func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) {
t.Helper()
protoItem, err := proto.Marshal(item)
require.NoError(t, err)
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, ts, index), protoItem)
}
// makeSegmentPath formats the path for a segment at a particular time and index
func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string {
t.Helper()
return fmt.Sprintf("%s%s%d/%d", ActivityPrefix, typ, ts.Unix(), index)
}
// TestSegmentFileReader_BadData verifies that the reader returns errors when the data can't be parsed.
// However, the next time that Read*() is called, the reader should still progress and be able to return
// any remaining valid data without errors
func TestSegmentFileReader_BadData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
// write bad data at index 0 that can't be unmarshaled
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data"))
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data"))
// write entity at index 1
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: "id",
},
}}
writeEntitySegment(t, core, now, 1, entity)
// write token at index 1
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
"ns": 1,
}}
writeTokenSegment(t, core, now, 1, token)
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)
// first the bad entity is read, which returns an error
_, err = reader.ReadEntity(context.Background())
require.Error(t, err)
// then, the reader can read the good entity at index 1
gotEntity, err := reader.ReadEntity(context.Background())
require.True(t, proto.Equal(gotEntity, entity))
require.Nil(t, err)
// the bad token causes an error
_, err = reader.ReadToken(context.Background())
require.Error(t, err)
// but the good token is able to be read
gotToken, err := reader.ReadToken(context.Background())
require.True(t, proto.Equal(gotToken, token))
require.Nil(t, err)
}
// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths
// without erroring until it is able to find a valid segment path
func TestSegmentFileReader_MissingData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
// write entities and tokens at indexes 0, 1, 2
for i := 0; i < 3; i++ {
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data"))
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data"))
}
// write entity at index 3
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: "id",
},
}}
writeEntitySegment(t, core, now, 3, entity)
// write token at index 3
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
"ns": 1,
}}
writeTokenSegment(t, core, now, 3, token)
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)
// delete the indexes 0, 1, 2
for i := 0; i < 3; i++ {
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenBasePath, now, i)))
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i)))
}
// we expect the reader to only return the data at index 3, and then be done
gotEntity, err := reader.ReadEntity(context.Background())
require.NoError(t, err)
require.True(t, proto.Equal(gotEntity, entity))
_, err = reader.ReadEntity(context.Background())
require.Equal(t, err, io.EOF)
gotToken, err := reader.ReadToken(context.Background())
require.NoError(t, err)
require.True(t, proto.Equal(gotToken, token))
_, err = reader.ReadToken(context.Background())
require.Equal(t, err, io.EOF)
}
// TestSegmentFileReader_NoData verifies that the reader returns io.EOF when there is no data
func TestSegmentFileReader_NoData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)
entity, err := reader.ReadEntity(context.Background())
require.Nil(t, entity)
require.Equal(t, err, io.EOF)
token, err := reader.ReadToken(context.Background())
require.Nil(t, token)
require.Equal(t, err, io.EOF)
}
// TestSegmentFileReader verifies that the reader iterates through all segments paths in ascending order and returns
// io.EOF when it's done
func TestSegmentFileReader(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
entities := make([]*activity.EntityActivityLog, 0, 3)
tokens := make([]*activity.TokenCount, 0, 3)
// write 3 entity segment pieces and 3 token segment pieces
for i := 0; i < 3; i++ {
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: fmt.Sprintf("id-%d", i),
},
}}
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
fmt.Sprintf("ns-%d", i): uint64(i),
}}
writeEntitySegment(t, core, now, i, entity)
writeTokenSegment(t, core, now, i, token)
entities = append(entities, entity)
tokens = append(tokens, token)
}
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)
gotEntities := make([]*activity.EntityActivityLog, 0, 3)
gotTokens := make([]*activity.TokenCount, 0, 3)
// read the entities from the reader
for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) {
require.NoError(t, err)
gotEntities = append(gotEntities, entity)
}
// read the tokens from the reader
for token, err := reader.ReadToken(context.Background()); !errors.Is(err, io.EOF); token, err = reader.ReadToken(context.Background()) {
require.NoError(t, err)
gotTokens = append(gotTokens, token)
}
require.Len(t, gotEntities, 3)
require.Len(t, gotTokens, 3)
// verify that the entities and tokens we got from the reader are correct
// we can't use require.Equal() here because there are protobuf differences in unexported fields
for i := 0; i < 3; i++ {
require.True(t, proto.Equal(gotEntities[i], entities[i]))
require.True(t, proto.Equal(gotTokens[i], tokens[i]))
}
}