open-vault/vault/activity_log_util_common_test.go

344 lines
13 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"
"github.com/axiomhq/hyperloglog"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/vault/activity"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
// Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal creates 3 months of hyperloglogs and fills them with
// overlapping clients. The test calls computeCurrentMonthForBillingPeriodInternal with the current month map having
// some overlap with the previous months. The test then verifies that the results have the correct number of entity and
// non-entity clients. The test also calls computeCurrentMonthForBillingPeriodInternal with an empty current month map,
// and verifies that the results are all 0.
func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T) {
// populate the first month with clients 1-10
monthOneHLL := hyperloglog.New()
// populate the second month with clients 5-15
monthTwoHLL := hyperloglog.New()
// populate the third month with clients 10-20
monthThreeHLL := hyperloglog.New()
for i := 0; i < 20; i++ {
clientID := []byte(fmt.Sprintf("client_%d", i))
if i < 10 {
monthOneHLL.Insert(clientID)
}
if 5 <= i && i < 15 {
monthTwoHLL.Insert(clientID)
}
if 10 <= i && i < 20 {
monthThreeHLL.Insert(clientID)
}
}
mockHLLGetFunc := func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error) {
currMonthStart := timeutil.StartOfMonth(time.Now())
if startTime.Equal(timeutil.MonthsPreviousTo(3, currMonthStart)) {
return monthThreeHLL, nil
}
if startTime.Equal(timeutil.MonthsPreviousTo(2, currMonthStart)) {
return monthTwoHLL, nil
}
if startTime.Equal(timeutil.MonthsPreviousTo(1, currMonthStart)) {
return monthOneHLL, nil
}
return nil, fmt.Errorf("bad start time")
}
// Let's add 2 entities exclusive to month 1 (clients 0,1),
// 2 entities shared by month 1 and 2 (clients 5,6),
// 2 entities shared by month 2 and 3 (clients 10,11), and
// 2 entities exclusive to month 3 (15,16). Furthermore, we can add
// 3 new entities (clients 20,21, and 22).
entitiesStruct := make(map[string]struct{}, 0)
entitiesStruct["client_0"] = struct{}{}
entitiesStruct["client_1"] = struct{}{}
entitiesStruct["client_5"] = struct{}{}
entitiesStruct["client_6"] = struct{}{}
entitiesStruct["client_10"] = struct{}{}
entitiesStruct["client_11"] = struct{}{}
entitiesStruct["client_15"] = struct{}{}
entitiesStruct["client_16"] = struct{}{}
entitiesStruct["client_20"] = struct{}{}
entitiesStruct["client_21"] = struct{}{}
entitiesStruct["client_22"] = struct{}{}
// We will add 3 nonentity clients from month 1 (clients 2,3,4),
// 3 shared by months 1 and 2 (7,8,9),
// 3 shared by months 2 and 3 (12,13,14), and
// 3 exclusive to month 3 (17,18,19). We will also
// add 4 new nonentity clients.
nonEntitiesStruct := make(map[string]struct{}, 0)
nonEntitiesStruct["client_2"] = struct{}{}
nonEntitiesStruct["client_3"] = struct{}{}
nonEntitiesStruct["client_4"] = struct{}{}
nonEntitiesStruct["client_7"] = struct{}{}
nonEntitiesStruct["client_8"] = struct{}{}
nonEntitiesStruct["client_9"] = struct{}{}
nonEntitiesStruct["client_12"] = struct{}{}
nonEntitiesStruct["client_13"] = struct{}{}
nonEntitiesStruct["client_14"] = struct{}{}
nonEntitiesStruct["client_17"] = struct{}{}
nonEntitiesStruct["client_18"] = struct{}{}
nonEntitiesStruct["client_19"] = struct{}{}
nonEntitiesStruct["client_23"] = struct{}{}
nonEntitiesStruct["client_24"] = struct{}{}
nonEntitiesStruct["client_25"] = struct{}{}
nonEntitiesStruct["client_26"] = struct{}{}
counts := &processCounts{
Entities: entitiesStruct,
NonEntities: nonEntitiesStruct,
}
currentMonthClientsMap := make(map[int64]*processMonth, 1)
currentMonthClients := &processMonth{
Counts: counts,
NewClients: &processNewClients{Counts: counts},
}
// Technially I think currentMonthClientsMap should have the keys as
// unix timestamps, but for the purposes of the unit test it doesn't
// matter what the values actually are.
currentMonthClientsMap[0] = currentMonthClients
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
endTime := timeutil.StartOfMonth(time.Now())
startTime := timeutil.MonthsPreviousTo(3, endTime)
monthRecord, err := a.computeCurrentMonthForBillingPeriodInternal(context.Background(), currentMonthClientsMap, mockHLLGetFunc, startTime, endTime)
if err != nil {
t.Fatal(err)
}
// We should have 11 entity clients and 16 nonentity clients, and 3 new entity clients
// and 4 new nonentity clients
if monthRecord.Counts.EntityClients != 11 {
t.Fatalf("wrong number of entity clients. Expected 11, got %d", monthRecord.Counts.EntityClients)
}
if monthRecord.Counts.NonEntityClients != 16 {
t.Fatalf("wrong number of non entity clients. Expected 16, got %d", monthRecord.Counts.NonEntityClients)
}
if monthRecord.NewClients.Counts.EntityClients != 3 {
t.Fatalf("wrong number of new entity clients. Expected 3, got %d", monthRecord.NewClients.Counts.EntityClients)
}
if monthRecord.NewClients.Counts.NonEntityClients != 4 {
t.Fatalf("wrong number of new non entity clients. Expected 4, got %d", monthRecord.NewClients.Counts.NonEntityClients)
}
// Attempt to compute current month when no records exist
endTime = time.Now().UTC()
startTime = timeutil.StartOfMonth(endTime)
emptyClientsMap := make(map[int64]*processMonth, 0)
monthRecord, err = a.computeCurrentMonthForBillingPeriodInternal(context.Background(), emptyClientsMap, mockHLLGetFunc, startTime, endTime)
if err != nil {
t.Fatalf("failed to compute empty current month, err: %v", err)
}
// We should have 0 entity clients, nonentity clients,new entity clients
// and new nonentity clients
if monthRecord.Counts.EntityClients != 0 {
t.Fatalf("wrong number of entity clients. Expected 0, got %d", monthRecord.Counts.EntityClients)
}
if monthRecord.Counts.NonEntityClients != 0 {
t.Fatalf("wrong number of non entity clients. Expected 0, got %d", monthRecord.Counts.NonEntityClients)
}
if monthRecord.NewClients.Counts.EntityClients != 0 {
t.Fatalf("wrong number of new entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.EntityClients)
}
if monthRecord.NewClients.Counts.NonEntityClients != 0 {
t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients)
}
}
// writeEntitySegment protobuf-encodes item and stores it through WriteToStorage
// at the entity segment path built from ts and index. Any marshal error fails
// the test immediately.
func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) {
	t.Helper()
	encoded, marshalErr := proto.Marshal(item)
	require.NoError(t, marshalErr)
	segmentPath := makeSegmentPath(t, activityEntityBasePath, ts, index)
	WriteToStorage(t, core, segmentPath, encoded)
}
// writeTokenSegment protobuf-encodes item and stores it through WriteToStorage
// at the token segment path built from ts and index. Any marshal error fails
// the test immediately.
func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) {
	t.Helper()
	encoded, marshalErr := proto.Marshal(item)
	require.NoError(t, marshalErr)
	segmentPath := makeSegmentPath(t, activityTokenBasePath, ts, index)
	WriteToStorage(t, core, segmentPath, encoded)
}
// makeSegmentPath builds the storage path of a segment as
// ActivityPrefix + typ + "<unix seconds of ts>/<index>".
func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string {
	t.Helper()
	suffix := fmt.Sprintf("%d/%d", ts.Unix(), index)
	return ActivityPrefix + typ + suffix
}
// TestSegmentFileReader_BadData verifies that the reader returns errors when the data is unable to be parsed.
// However, the next time that Read*() is called, the reader should still progress and be able to then return any
// valid data without errors.
func TestSegmentFileReader_BadData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	now := time.Now()

	// write bad data that won't be able to be unmarshaled at index 0
	WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data"))
	WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data"))

	// write entity at index 1
	entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
		{
			ClientID: "id",
		},
	}}
	writeEntitySegment(t, core, now, 1, entity)

	// write token at index 1
	token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
		"ns": 1,
	}}
	writeTokenSegment(t, core, now, 1, token)

	reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
	require.NoError(t, err)

	// first the bad entity is read, which returns an error
	_, err = reader.ReadEntity(context.Background())
	require.Error(t, err)
	// then, the reader can read the good entity at index 1. Check the error
	// first so a read failure is reported as such rather than as a mismatch.
	gotEntity, err := reader.ReadEntity(context.Background())
	require.NoError(t, err)
	require.True(t, proto.Equal(gotEntity, entity))

	// the bad token causes an error
	_, err = reader.ReadToken(context.Background())
	require.Error(t, err)
	// but the good token is able to be read
	gotToken, err := reader.ReadToken(context.Background())
	require.NoError(t, err)
	require.True(t, proto.Equal(gotToken, token))
}
// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths without
// erroring until it is able to find a valid segment path.
func TestSegmentFileReader_MissingData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	now := time.Now()

	// write placeholder data for entities and tokens at indexes 0, 1, 2; these
	// paths are deleted below once the reader has been created
	for i := 0; i < 3; i++ {
		WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data"))
		WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data"))
	}

	// write entity at index 3
	entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
		{
			ClientID: "id",
		},
	}}
	writeEntitySegment(t, core, now, 3, entity)

	// write token at index 3
	token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
		"ns": 1,
	}}
	writeTokenSegment(t, core, now, 3, token)

	reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
	require.NoError(t, err)

	// delete the indexes 0, 1, 2 after the reader exists
	// NOTE(review): presumably the reader captured the segment paths at
	// creation, so these now refer to missing entries — confirm in the reader.
	for i := 0; i < 3; i++ {
		require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenBasePath, now, i)))
		require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i)))
	}

	// we expect the reader to only return the data at index 3, and then be done
	gotEntity, err := reader.ReadEntity(context.Background())
	require.NoError(t, err)
	require.True(t, proto.Equal(gotEntity, entity))
	_, err = reader.ReadEntity(context.Background())
	require.ErrorIs(t, err, io.EOF)

	gotToken, err := reader.ReadToken(context.Background())
	require.NoError(t, err)
	require.True(t, proto.Equal(gotToken, token))
	_, err = reader.ReadToken(context.Background())
	require.ErrorIs(t, err, io.EOF)
}
// TestSegmentFileReader_NoData verifies that the reader returns a nil record and io.EOF when there is no data.
func TestSegmentFileReader_NoData(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	now := time.Now()
	reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
	require.NoError(t, err)

	entity, err := reader.ReadEntity(context.Background())
	require.Nil(t, entity)
	require.ErrorIs(t, err, io.EOF)

	token, err := reader.ReadToken(context.Background())
	require.Nil(t, token)
	require.ErrorIs(t, err, io.EOF)
}
// TestSegmentFileReader writes three entity segments and three token segments, then
// checks that the reader yields them back in index order and finishes with io.EOF.
func TestSegmentFileReader(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	now := time.Now()
	ctx := context.Background()
	const segmentCount = 3

	wantEntities := make([]*activity.EntityActivityLog, 0, segmentCount)
	wantTokens := make([]*activity.TokenCount, 0, segmentCount)
	for idx := 0; idx < segmentCount; idx++ {
		record := &activity.EntityRecord{ClientID: fmt.Sprintf("id-%d", idx)}
		entitySegment := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{record}}
		tokenSegment := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
			fmt.Sprintf("ns-%d", idx): uint64(idx),
		}}
		writeEntitySegment(t, core, now, idx, entitySegment)
		writeTokenSegment(t, core, now, idx, tokenSegment)
		wantEntities = append(wantEntities, entitySegment)
		wantTokens = append(wantTokens, tokenSegment)
	}

	reader, err := core.activityLog.NewSegmentFileReader(ctx, now)
	require.NoError(t, err)

	// drain every entity segment until the reader reports io.EOF
	gotEntities := make([]*activity.EntityActivityLog, 0, segmentCount)
	for {
		got, readErr := reader.ReadEntity(ctx)
		if errors.Is(readErr, io.EOF) {
			break
		}
		require.NoError(t, readErr)
		gotEntities = append(gotEntities, got)
	}

	// drain every token segment until the reader reports io.EOF
	gotTokens := make([]*activity.TokenCount, 0, segmentCount)
	for {
		got, readErr := reader.ReadToken(ctx)
		if errors.Is(readErr, io.EOF) {
			break
		}
		require.NoError(t, readErr)
		gotTokens = append(gotTokens, got)
	}

	require.Len(t, gotEntities, segmentCount)
	require.Len(t, gotTokens, segmentCount)

	// proto.Equal is used instead of require.Equal because protobuf messages
	// carry unexported internal state that would make a deep-equal comparison fail
	for idx := range wantEntities {
		require.True(t, proto.Equal(gotEntities[idx], wantEntities[idx]))
		require.True(t, proto.Equal(gotTokens[idx], wantTokens[idx]))
	}
}