2023-03-15 16:00:52 +00:00
|
|
|
// Copyright (c) HashiCorp, Inc.
|
|
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
|
2022-06-29 17:51:23 +00:00
|
|
|
package vault
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2023-04-06 14:23:41 +00:00
|
|
|
"errors"
|
2022-06-29 17:51:23 +00:00
|
|
|
"fmt"
|
2023-04-06 14:23:41 +00:00
|
|
|
"io"
|
2022-06-29 17:51:23 +00:00
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/axiomhq/hyperloglog"
|
|
|
|
"github.com/hashicorp/vault/helper/timeutil"
|
2023-04-06 14:23:41 +00:00
|
|
|
"github.com/hashicorp/vault/vault/activity"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"google.golang.org/protobuf/proto"
|
2022-06-29 17:51:23 +00:00
|
|
|
)
|
|
|
|
|
2023-03-06 12:08:22 +00:00
|
|
|
// Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal creates 3 months of hyperloglogs and fills them with
|
|
|
|
// overlapping clients. The test calls computeCurrentMonthForBillingPeriodInternal with the current month map having
|
|
|
|
// some overlap with the previous months. The test then verifies that the results have the correct number of entity and
|
|
|
|
// non-entity clients. The test also calls computeCurrentMonthForBillingPeriodInternal with an empty current month map,
|
|
|
|
// and verifies that the results are all 0.
|
2022-06-29 17:51:23 +00:00
|
|
|
func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T) {
|
|
|
|
// populate the first month with clients 1-10
|
|
|
|
monthOneHLL := hyperloglog.New()
|
|
|
|
// populate the second month with clients 5-15
|
|
|
|
monthTwoHLL := hyperloglog.New()
|
|
|
|
// populate the third month with clients 10-20
|
|
|
|
monthThreeHLL := hyperloglog.New()
|
|
|
|
|
|
|
|
for i := 0; i < 20; i++ {
|
|
|
|
clientID := []byte(fmt.Sprintf("client_%d", i))
|
|
|
|
if i < 10 {
|
|
|
|
monthOneHLL.Insert(clientID)
|
|
|
|
}
|
|
|
|
if 5 <= i && i < 15 {
|
|
|
|
monthTwoHLL.Insert(clientID)
|
|
|
|
}
|
|
|
|
if 10 <= i && i < 20 {
|
|
|
|
monthThreeHLL.Insert(clientID)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
mockHLLGetFunc := func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error) {
|
|
|
|
currMonthStart := timeutil.StartOfMonth(time.Now())
|
|
|
|
if startTime.Equal(timeutil.MonthsPreviousTo(3, currMonthStart)) {
|
|
|
|
return monthThreeHLL, nil
|
|
|
|
}
|
|
|
|
if startTime.Equal(timeutil.MonthsPreviousTo(2, currMonthStart)) {
|
|
|
|
return monthTwoHLL, nil
|
|
|
|
}
|
|
|
|
if startTime.Equal(timeutil.MonthsPreviousTo(1, currMonthStart)) {
|
|
|
|
return monthOneHLL, nil
|
|
|
|
}
|
|
|
|
return nil, fmt.Errorf("bad start time")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Let's add 2 entities exclusive to month 1 (clients 0,1),
|
|
|
|
// 2 entities shared by month 1 and 2 (clients 5,6),
|
|
|
|
// 2 entities shared by month 2 and 3 (clients 10,11), and
|
|
|
|
// 2 entities exclusive to month 3 (15,16). Furthermore, we can add
|
|
|
|
// 3 new entities (clients 20,21, and 22).
|
|
|
|
entitiesStruct := make(map[string]struct{}, 0)
|
|
|
|
entitiesStruct["client_0"] = struct{}{}
|
|
|
|
entitiesStruct["client_1"] = struct{}{}
|
|
|
|
entitiesStruct["client_5"] = struct{}{}
|
|
|
|
entitiesStruct["client_6"] = struct{}{}
|
|
|
|
entitiesStruct["client_10"] = struct{}{}
|
|
|
|
entitiesStruct["client_11"] = struct{}{}
|
|
|
|
entitiesStruct["client_15"] = struct{}{}
|
|
|
|
entitiesStruct["client_16"] = struct{}{}
|
|
|
|
entitiesStruct["client_20"] = struct{}{}
|
|
|
|
entitiesStruct["client_21"] = struct{}{}
|
|
|
|
entitiesStruct["client_22"] = struct{}{}
|
|
|
|
|
|
|
|
// We will add 3 nonentity clients from month 1 (clients 2,3,4),
|
|
|
|
// 3 shared by months 1 and 2 (7,8,9),
|
|
|
|
// 3 shared by months 2 and 3 (12,13,14), and
|
|
|
|
// 3 exclusive to month 3 (17,18,19). We will also
|
|
|
|
// add 4 new nonentity clients.
|
|
|
|
nonEntitiesStruct := make(map[string]struct{}, 0)
|
|
|
|
nonEntitiesStruct["client_2"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_3"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_4"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_7"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_8"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_9"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_12"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_13"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_14"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_17"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_18"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_19"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_23"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_24"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_25"] = struct{}{}
|
|
|
|
nonEntitiesStruct["client_26"] = struct{}{}
|
|
|
|
|
|
|
|
counts := &processCounts{
|
|
|
|
Entities: entitiesStruct,
|
|
|
|
NonEntities: nonEntitiesStruct,
|
|
|
|
}
|
|
|
|
|
|
|
|
currentMonthClientsMap := make(map[int64]*processMonth, 1)
|
|
|
|
currentMonthClients := &processMonth{
|
|
|
|
Counts: counts,
|
|
|
|
NewClients: &processNewClients{Counts: counts},
|
|
|
|
}
|
|
|
|
// Technially I think currentMonthClientsMap should have the keys as
|
|
|
|
// unix timestamps, but for the purposes of the unit test it doesn't
|
|
|
|
// matter what the values actually are.
|
|
|
|
currentMonthClientsMap[0] = currentMonthClients
|
|
|
|
|
|
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
a := core.activityLog
|
|
|
|
|
|
|
|
endTime := timeutil.StartOfMonth(time.Now())
|
|
|
|
startTime := timeutil.MonthsPreviousTo(3, endTime)
|
|
|
|
|
|
|
|
monthRecord, err := a.computeCurrentMonthForBillingPeriodInternal(context.Background(), currentMonthClientsMap, mockHLLGetFunc, startTime, endTime)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// We should have 11 entity clients and 16 nonentity clients, and 3 new entity clients
|
|
|
|
// and 4 new nonentity clients
|
|
|
|
if monthRecord.Counts.EntityClients != 11 {
|
|
|
|
t.Fatalf("wrong number of entity clients. Expected 11, got %d", monthRecord.Counts.EntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.Counts.NonEntityClients != 16 {
|
|
|
|
t.Fatalf("wrong number of non entity clients. Expected 16, got %d", monthRecord.Counts.NonEntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.NewClients.Counts.EntityClients != 3 {
|
|
|
|
t.Fatalf("wrong number of new entity clients. Expected 3, got %d", monthRecord.NewClients.Counts.EntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.NewClients.Counts.NonEntityClients != 4 {
|
|
|
|
t.Fatalf("wrong number of new non entity clients. Expected 4, got %d", monthRecord.NewClients.Counts.NonEntityClients)
|
|
|
|
}
|
2022-10-07 16:37:09 +00:00
|
|
|
|
|
|
|
// Attempt to compute current month when no records exist
|
|
|
|
endTime = time.Now().UTC()
|
|
|
|
startTime = timeutil.StartOfMonth(endTime)
|
|
|
|
emptyClientsMap := make(map[int64]*processMonth, 0)
|
|
|
|
monthRecord, err = a.computeCurrentMonthForBillingPeriodInternal(context.Background(), emptyClientsMap, mockHLLGetFunc, startTime, endTime)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("failed to compute empty current month, err: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// We should have 0 entity clients, nonentity clients,new entity clients
|
|
|
|
// and new nonentity clients
|
|
|
|
if monthRecord.Counts.EntityClients != 0 {
|
|
|
|
t.Fatalf("wrong number of entity clients. Expected 0, got %d", monthRecord.Counts.EntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.Counts.NonEntityClients != 0 {
|
|
|
|
t.Fatalf("wrong number of non entity clients. Expected 0, got %d", monthRecord.Counts.NonEntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.NewClients.Counts.EntityClients != 0 {
|
|
|
|
t.Fatalf("wrong number of new entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.EntityClients)
|
|
|
|
}
|
|
|
|
if monthRecord.NewClients.Counts.NonEntityClients != 0 {
|
|
|
|
t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients)
|
|
|
|
}
|
2022-06-29 17:51:23 +00:00
|
|
|
}
|
2023-04-06 14:23:41 +00:00
|
|
|
|
|
|
|
// writeEntitySegment writes a single segment file with the given time and index for an entity
|
|
|
|
func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) {
|
|
|
|
t.Helper()
|
|
|
|
protoItem, err := proto.Marshal(item)
|
|
|
|
require.NoError(t, err)
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem)
|
|
|
|
}
|
|
|
|
|
|
|
|
// writeTokenSegment writes a single segment file with the given time and index for a token
|
|
|
|
func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) {
|
|
|
|
t.Helper()
|
|
|
|
protoItem, err := proto.Marshal(item)
|
|
|
|
require.NoError(t, err)
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, ts, index), protoItem)
|
|
|
|
}
|
|
|
|
|
|
|
|
// makeSegmentPath formats the path for a segment at a particular time and index
|
|
|
|
func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string {
|
|
|
|
t.Helper()
|
|
|
|
return fmt.Sprintf("%s%s%d/%d", ActivityPrefix, typ, ts.Unix(), index)
|
|
|
|
}
|
|
|
|
|
|
|
|
// TestSegmentFileReader_BadData verifies that the reader returns errors when the data is unable to be parsed
|
|
|
|
// However, the next time that Read*() is called, the reader should still progress and be able to then return any
|
|
|
|
// valid data without errors
|
|
|
|
func TestSegmentFileReader_BadData(t *testing.T) {
|
|
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
now := time.Now()
|
|
|
|
|
|
|
|
// write bad data that won't be able to be unmarshaled at index 0
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data"))
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data"))
|
|
|
|
|
|
|
|
// write entity at index 1
|
|
|
|
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
|
|
|
|
{
|
|
|
|
ClientID: "id",
|
|
|
|
},
|
|
|
|
}}
|
|
|
|
writeEntitySegment(t, core, now, 1, entity)
|
|
|
|
|
|
|
|
// write token at index 1
|
|
|
|
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
|
|
|
|
"ns": 1,
|
|
|
|
}}
|
|
|
|
writeTokenSegment(t, core, now, 1, token)
|
|
|
|
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
// first the bad entity is read, which returns an error
|
|
|
|
_, err = reader.ReadEntity(context.Background())
|
|
|
|
require.Error(t, err)
|
|
|
|
// then, the reader can read the good entity at index 1
|
|
|
|
gotEntity, err := reader.ReadEntity(context.Background())
|
|
|
|
require.True(t, proto.Equal(gotEntity, entity))
|
|
|
|
require.Nil(t, err)
|
|
|
|
|
|
|
|
// the bad token causes an error
|
|
|
|
_, err = reader.ReadToken(context.Background())
|
|
|
|
require.Error(t, err)
|
|
|
|
// but the good token is able to be read
|
|
|
|
gotToken, err := reader.ReadToken(context.Background())
|
|
|
|
require.True(t, proto.Equal(gotToken, token))
|
|
|
|
require.Nil(t, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths without
|
|
|
|
// errorring until it is able to find a valid segment path
|
|
|
|
func TestSegmentFileReader_MissingData(t *testing.T) {
|
|
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
now := time.Now()
|
|
|
|
// write entities and tokens at indexes 0, 1, 2
|
|
|
|
for i := 0; i < 3; i++ {
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data"))
|
|
|
|
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data"))
|
|
|
|
|
|
|
|
}
|
|
|
|
// write entity at index 3
|
|
|
|
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
|
|
|
|
{
|
|
|
|
ClientID: "id",
|
|
|
|
},
|
|
|
|
}}
|
|
|
|
writeEntitySegment(t, core, now, 3, entity)
|
|
|
|
// write token at index 3
|
|
|
|
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
|
|
|
|
"ns": 1,
|
|
|
|
}}
|
|
|
|
writeTokenSegment(t, core, now, 3, token)
|
|
|
|
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
// delete the indexes 0, 1, 2
|
|
|
|
for i := 0; i < 3; i++ {
|
|
|
|
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenBasePath, now, i)))
|
|
|
|
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i)))
|
|
|
|
}
|
|
|
|
|
|
|
|
// we expect the reader to only return the data at index 3, and then be done
|
|
|
|
gotEntity, err := reader.ReadEntity(context.Background())
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, proto.Equal(gotEntity, entity))
|
|
|
|
_, err = reader.ReadEntity(context.Background())
|
|
|
|
require.Equal(t, err, io.EOF)
|
|
|
|
|
|
|
|
gotToken, err := reader.ReadToken(context.Background())
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, proto.Equal(gotToken, token))
|
|
|
|
_, err = reader.ReadToken(context.Background())
|
|
|
|
require.Equal(t, err, io.EOF)
|
|
|
|
}
|
|
|
|
|
|
|
|
// TestSegmentFileReader_NoData verifies that the reader return io.EOF when there is no data
|
|
|
|
func TestSegmentFileReader_NoData(t *testing.T) {
|
|
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
now := time.Now()
|
|
|
|
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
|
|
|
|
require.NoError(t, err)
|
|
|
|
entity, err := reader.ReadEntity(context.Background())
|
|
|
|
require.Nil(t, entity)
|
|
|
|
require.Equal(t, err, io.EOF)
|
|
|
|
token, err := reader.ReadToken(context.Background())
|
|
|
|
require.Nil(t, token)
|
|
|
|
require.Equal(t, err, io.EOF)
|
|
|
|
}
|
|
|
|
|
|
|
|
// TestSegmentFileReader verifies that the reader iterates through all segments paths in ascending order and returns
|
|
|
|
// io.EOF when it's done
|
|
|
|
func TestSegmentFileReader(t *testing.T) {
|
|
|
|
core, _, _ := TestCoreUnsealed(t)
|
|
|
|
now := time.Now()
|
|
|
|
entities := make([]*activity.EntityActivityLog, 0, 3)
|
|
|
|
tokens := make([]*activity.TokenCount, 0, 3)
|
|
|
|
|
|
|
|
// write 3 entity segment pieces and 3 token segment pieces
|
|
|
|
for i := 0; i < 3; i++ {
|
|
|
|
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
|
|
|
|
{
|
|
|
|
ClientID: fmt.Sprintf("id-%d", i),
|
|
|
|
},
|
|
|
|
}}
|
|
|
|
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
|
|
|
|
fmt.Sprintf("ns-%d", i): uint64(i),
|
|
|
|
}}
|
|
|
|
writeEntitySegment(t, core, now, i, entity)
|
|
|
|
writeTokenSegment(t, core, now, i, token)
|
|
|
|
entities = append(entities, entity)
|
|
|
|
tokens = append(tokens, token)
|
|
|
|
}
|
|
|
|
|
|
|
|
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
gotEntities := make([]*activity.EntityActivityLog, 0, 3)
|
|
|
|
gotTokens := make([]*activity.TokenCount, 0, 3)
|
|
|
|
|
|
|
|
// read the entities from the reader
|
|
|
|
for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) {
|
|
|
|
require.NoError(t, err)
|
|
|
|
gotEntities = append(gotEntities, entity)
|
|
|
|
}
|
|
|
|
|
|
|
|
// read the tokens from the reader
|
|
|
|
for token, err := reader.ReadToken(context.Background()); !errors.Is(err, io.EOF); token, err = reader.ReadToken(context.Background()) {
|
|
|
|
require.NoError(t, err)
|
|
|
|
gotTokens = append(gotTokens, token)
|
|
|
|
}
|
|
|
|
require.Len(t, gotEntities, 3)
|
|
|
|
require.Len(t, gotTokens, 3)
|
|
|
|
|
|
|
|
// verify that the entities and tokens we got from the reader are correct
|
|
|
|
// we can't use require.Equals() here because there are protobuf differences in unexported fields
|
|
|
|
for i := 0; i < 3; i++ {
|
|
|
|
require.True(t, proto.Equal(gotEntities[i], entities[i]))
|
|
|
|
require.True(t, proto.Equal(gotTokens[i], tokens[i]))
|
|
|
|
}
|
|
|
|
}
|