// open-vault/vault/activity_log_test.go

package vault
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/go-test/deep"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/activity"
)
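// TestActivityLog_Creation verifies that AddEntityToFragment and
// AddTokenToFragment lazily create a fragment and record the entity ID,
// namespace ID, timestamp, and non-entity token counts passed to them.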
func TestActivityLog_Creation(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
a.SetEnable(true)
if a == nil {
t.Fatal("no activity log found")
}
if a.logger == nil || a.view == nil {
t.Fatal("activity log not initialized")
}
if a.fragment != nil {
t.Fatal("activity log already has fragment")
}
const entity_id = "entity_id_75432"
const namespace_id = "ns123"
ts := time.Now()
a.AddEntityToFragment(entity_id, namespace_id, ts.Unix())
if a.fragment == nil {
t.Fatal("no fragment created")
}
if a.fragment.OriginatingNode != a.nodeID {
t.Errorf("mismatched node ID, %q vs %q", a.fragment.OriginatingNode, a.nodeID)
}
if a.fragment.Entities == nil {
t.Fatal("no fragment entity slice")
}
if a.fragment.NonEntityTokens == nil {
t.Fatal("no fragment token map")
}
if len(a.fragment.Entities) != 1 {
t.Fatalf("wrong number of entities %v", len(a.fragment.Entities))
}
er := a.fragment.Entities[0]
if er.EntityID != entity_id {
t.Errorf("mimatched entity ID, %q vs %q", er.EntityID, entity_id)
}
if er.NamespaceID != namespace_id {
t.Errorf("mimatched namespace ID, %q vs %q", er.NamespaceID, namespace_id)
}
if er.Timestamp != ts.Unix() {
t.Errorf("mimatched timestamp, %v vs %v", er.Timestamp, ts.Unix())
}
// Reset and test the other code path
a.fragment = nil
a.AddTokenToFragment(namespace_id)
if a.fragment == nil {
t.Fatal("no fragment created")
}
if a.fragment.NonEntityTokens == nil {
t.Fatal("no fragment token map")
}
actual := a.fragment.NonEntityTokens[namespace_id]
if actual != 1 {
t.Errorf("mismatched number of tokens, %v vs %v", actual, 1)
}
}
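// checkExpectedEntitiesInMap asserts that the core's active entity set
// contains exactly the given entity IDs.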
func checkExpectedEntitiesInMap(t *testing.T, a *ActivityLog, entityIDs []string) {
t.Helper()
activeEntities := a.core.GetActiveEntities()
if len(activeEntities) != len(entityIDs) {
t.Fatalf("mismatched number of entities, expected %v got %v", len(entityIDs), activeEntities)
}
for _, e := range entityIDs {
if _, present := activeEntities[e]; !present {
t.Errorf("entity ID %q is missing", e)
}
}
}
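// TestActivityLog_UniqueEntities verifies that adding the same entity ID
// multiple times records only its first occurrence in the fragment.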
func TestActivityLog_UniqueEntities(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
a.SetEnable(true)
id1 := "11111111-1111-1111-1111-111111111111"
t1 := time.Now()
id2 := "22222222-2222-2222-2222-222222222222"
t2 := time.Now()
t3 := t2.Add(60 * time.Second)
a.AddEntityToFragment(id1, "root", t1.Unix())
a.AddEntityToFragment(id2, "root", t2.Unix())
a.AddEntityToFragment(id2, "root", t3.Unix())
a.AddEntityToFragment(id1, "root", t3.Unix())
if a.fragment == nil {
t.Fatal("no current fragment")
}
if len(a.fragment.Entities) != 2 {
t.Fatalf("number of entities is %v", len(a.fragment.Entities))
}
for i, e := range a.fragment.Entities {
expectedID := id1
expectedTime := t1.Unix()
expectedNS := "root"
if i == 1 {
expectedID = id2
expectedTime = t2.Unix()
}
if e.EntityID != expectedID {
t.Errorf("%v: expected %q, got %q", i, expectedID, e.EntityID)
}
if e.NamespaceID != expectedNS {
t.Errorf("%v: expected %q, got %q", i, expectedNS, e.NamespaceID)
}
if e.Timestamp != expectedTime {
t.Errorf("%v: expected %v, got %v", i, expectedTime, e.Timestamp)
}
}
checkExpectedEntitiesInMap(t, a, []string{id1, id2})
}
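// readSegmentFromStorage fetches a segment from the barrier at the given path
// and fails the test if it is missing.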
func readSegmentFromStorage(t *testing.T, c *Core, path string) *logical.StorageEntry {
t.Helper()
logSegment, err := c.barrier.Get(context.Background(), path)
if err != nil {
t.Fatal(err)
}
if logSegment == nil {
t.Fatalf("expected non-nil log segment at %q", path)
}
return logSegment
}
func expectMissingSegment(t *testing.T, c *Core, path string) {
t.Helper()
logSegment, err := c.barrier.Get(context.Background(), path)
if err != nil {
t.Fatal(err)
}
if logSegment != nil {
t.Fatalf("expected nil log segment at %q", path)
}
}
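// expectedEntityIDs asserts that the deserialized entity log contains exactly
// the given entity IDs, in any order.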
func expectedEntityIDs(t *testing.T, out *activity.EntityActivityLog, ids []string) {
t.Helper()
if len(out.Entities) != len(ids) {
t.Fatalf("entity log expected length %v, actual %v", len(ids), len(out.Entities))
}
// Double loop, OK for small cases
for _, id := range ids {
found := false
for _, e := range out.Entities {
if e.EntityID == id {
found = true
break
}
}
if !found {
t.Errorf("did not find entity ID %v", id)
}
}
}
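// TestActivityLog_SaveTokensToStorage verifies that token counts are persisted
// under the directtokens/<startTime>/<segment> path and that counts accumulate
// across successive saves of the same segment.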
func TestActivityLog_SaveTokensToStorage(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
ctx := context.Background()
a := core.activityLog
a.SetStandbyEnable(ctx, true)
a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment
nsIDs := [...]string{"ns1_id", "ns2_id", "ns3_id"}
path := fmt.Sprintf("%sdirecttokens/%d/0", ActivityLogPrefix, a.GetStartTimestamp())
for i := 0; i < 3; i++ {
a.AddTokenToFragment(nsIDs[0])
}
a.AddTokenToFragment(nsIDs[1])
err := a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatalf("got error writing tokens to storage: %v", err)
}
if a.fragment != nil {
t.Errorf("fragment was not reset after write to storage")
}
protoSegment := readSegmentFromStorage(t, core, path)
out := &activity.TokenCount{}
err = proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
if len(out.CountByNamespaceID) != 2 {
t.Fatalf("unexpected token length. Expected %d, got %d", 2, len(out.CountByNamespaceID))
}
for i := 0; i < 2; i++ {
if _, ok := out.CountByNamespaceID[nsIDs[i]]; !ok {
t.Fatalf("namespace ID %s missing from token counts", nsIDs[i])
}
}
if out.CountByNamespaceID[nsIDs[0]] != 3 {
t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[0], out.CountByNamespaceID[nsIDs[0]], 3)
}
if out.CountByNamespaceID[nsIDs[1]] != 1 {
t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[1], out.CountByNamespaceID[nsIDs[1]], 1)
}
a.AddTokenToFragment(nsIDs[0])
a.AddTokenToFragment(nsIDs[2])
err = a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatalf("got error writing tokens to storage: %v", err)
}
if a.fragment != nil {
t.Errorf("fragment was not reset after write to storage")
}
protoSegment = readSegmentFromStorage(t, core, path)
out = &activity.TokenCount{}
err = proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
if len(out.CountByNamespaceID) != 3 {
t.Fatalf("unexpected token length. Expected %d, got %d", 3, len(out.CountByNamespaceID))
}
for i := 0; i < 3; i++ {
if _, ok := out.CountByNamespaceID[nsIDs[i]]; !ok {
t.Fatalf("namespace ID %s missing from token counts", nsIDs[i])
}
}
if out.CountByNamespaceID[nsIDs[0]] != 4 {
t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[0], out.CountByNamespaceID[nsIDs[0]], 4)
}
if out.CountByNamespaceID[nsIDs[1]] != 1 {
t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[1], out.CountByNamespaceID[nsIDs[1]], 1)
}
if out.CountByNamespaceID[nsIDs[2]] != 1 {
t.Errorf("namespace ID %s has %d count, expected %d", nsIDs[2], out.CountByNamespaceID[nsIDs[2]], 1)
}
}
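// TestActivityLog_SaveEntitiesToStorage verifies that entity records are
// persisted under the entity/<startTime>/<segment> path and that records from
// successive saves are merged into the stored segment.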
func TestActivityLog_SaveEntitiesToStorage(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
ctx := context.Background()
a := core.activityLog
a.SetStandbyEnable(ctx, true)
a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment
now := time.Now()
ids := []string{"11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", "33333333-2222-2222-2222-222222222222"}
times := [...]int64{
now.Unix(),
now.Add(1 * time.Second).Unix(),
now.Add(2 * time.Second).Unix(),
}
path := fmt.Sprintf("%sentity/%d/0", ActivityLogPrefix, a.GetStartTimestamp())
a.AddEntityToFragment(ids[0], "root", times[0])
a.AddEntityToFragment(ids[1], "root2", times[1])
err := a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatalf("got error writing entities to storage: %v", err)
}
if a.fragment != nil {
t.Errorf("fragment was not reset after write to storage")
}
protoSegment := readSegmentFromStorage(t, core, path)
out := &activity.EntityActivityLog{}
err = proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
expectedEntityIDs(t, out, ids[:2])
a.AddEntityToFragment(ids[0], "root", times[2])
a.AddEntityToFragment(ids[2], "root", times[2])
err = a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatalf("got error writing segments to storage: %v", err)
}
protoSegment = readSegmentFromStorage(t, core, path)
out = &activity.EntityActivityLog{}
err = proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
expectedEntityIDs(t, out, ids)
}
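// TestActivityLog_ReceivedFragment verifies that fragments received from
// standby nodes update the active entity set and are retained, even when they
// contain duplicate entities.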
func TestActivityLog_ReceivedFragment(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
a.SetEnable(true)
ids := []string{
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
}
entityRecords := []*activity.EntityRecord{
{
EntityID: ids[0],
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
{
EntityID: ids[1],
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
}
fragment := &activity.LogFragment{
OriginatingNode: "test-123",
Entities: entityRecords,
NonEntityTokens: make(map[string]uint64),
}
if len(a.standbyFragmentsReceived) != 0 {
t.Fatalf("fragment already received")
}
a.receivedFragment(fragment)
checkExpectedEntitiesInMap(t, a, ids)
if len(a.standbyFragmentsReceived) != 1 {
t.Fatalf("fragment count is %v, expected 1", len(a.standbyFragmentsReceived))
}
// Send a duplicate, should be stored but not change entity map
a.receivedFragment(fragment)
checkExpectedEntitiesInMap(t, a, ids)
if len(a.standbyFragmentsReceived) != 2 {
t.Fatalf("fragment count is %v, expected 2", len(a.standbyFragmentsReceived))
}
}
func TestActivityLog_availableLogsEmptyDirectory(t *testing.T) {
// verify that directory is empty, and nothing goes wrong
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
times, err := a.availableLogs(context.Background())
if err != nil {
t.Fatalf("error getting start_time(s) for empty activity log")
}
if len(times) != 0 {
t.Fatalf("invalid number of start_times returned. expected 0, got %d", len(times))
}
}
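// TestActivityLog_availableLogs verifies that start times are extracted from
// both entity/ and directtokens/ paths, deduplicated, and returned newest
// first.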
func TestActivityLog_availableLogs(t *testing.T) {
// set up a few files in storage
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
paths := [...]string{"entity/1111/1", "directtokens/1111/1", "directtokens/1000000/1", "entity/992/3", "directtokens/992/1"}
expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0)}
for _, path := range paths {
WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test"))
}
// verify above files are there, and dates in correct order
times, err := a.availableLogs(context.Background())
if err != nil {
t.Fatalf("error getting start_time(s) for activity log")
}
if len(times) != len(expectedTimes) {
t.Fatalf("invalid number of start_times returned. expected %d, got %d", len(expectedTimes), len(times))
}
for i := range times {
if !times[i].Equal(expectedTimes[i]) {
t.Errorf("invalid time. expected %v, got %v", expectedTimes[i], times[i])
}
}
}
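// TestActivityLog_MultipleFragmentsAndSegments verifies that local and
// received fragments are merged, that entities beyond
// activitySegmentEntityCapacity spill over into a second numbered segment
// (entity/<startTime>/0, entity/<startTime>/1), and that non-entity token
// counts from all fragments are summed per namespace.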
func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
// enabled check is now inside AddEntityToFragment
a.SetEnable(true)
a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment
// Stop timers for test purposes
close(a.doneCh)
startTimestamp := a.GetStartTimestamp()
path0 := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", startTimestamp)
path1 := fmt.Sprintf("sys/counters/activity/log/entity/%d/1", startTimestamp)
tokenPath := fmt.Sprintf("sys/counters/activity/log/directtokens/%d/0", startTimestamp)
genID := func(i int) string {
return fmt.Sprintf("11111111-1111-1111-1111-%012d", i)
}
ts := time.Now().Unix()
// First 7000 should fit in one segment
for i := 0; i < 7000; i++ {
a.AddEntityToFragment(genID(i), "root", ts)
}
// Consume new fragment notification.
// The worker may have gotten it first, before processing
// the close!
select {
case <-a.newFragmentCh:
default:
}
// Save incomplete segment
err := a.saveCurrentSegmentToStorage(context.Background(), false)
if err != nil {
t.Fatalf("got error writing entities to storage: %v", err)
}
protoSegment0 := readSegmentFromStorage(t, core, path0)
entityLog0 := activity.EntityActivityLog{}
err = proto.Unmarshal(protoSegment0.Value, &entityLog0)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
if len(entityLog0.Entities) != 7000 {
t.Fatalf("unexpected entity length. Expected %d, got %d", 7000, len(entityLog0.Entities))
}
// 7000 more local entities
for i := 7000; i < 14000; i++ {
a.AddEntityToFragment(genID(i), "root", ts)
}
// Simulated remote fragment with 100 duplicate entities
tokens1 := map[string]uint64{
"root": 3,
"aaaaa": 4,
"bbbbb": 5,
}
fragment1 := &activity.LogFragment{
OriginatingNode: "test-123",
Entities: make([]*activity.EntityRecord, 0, 100),
NonEntityTokens: tokens1,
}
for i := 7000; i < 7100; i++ {
fragment1.Entities = append(fragment1.Entities, &activity.EntityRecord{
EntityID: genID(i),
NamespaceID: "root",
Timestamp: ts,
})
}
// Simulated remote fragment with 100 new entities
tokens2 := map[string]uint64{
"root": 6,
"aaaaa": 7,
"bbbbb": 8,
}
fragment2 := &activity.LogFragment{
OriginatingNode: "test-123",
Entities: make([]*activity.EntityRecord, 0, 100),
NonEntityTokens: tokens2,
}
for i := 14000; i < 14100; i++ {
fragment2.Entities = append(fragment2.Entities, &activity.EntityRecord{
EntityID: genID(i),
NamespaceID: "root",
Timestamp: ts,
})
}
a.receivedFragment(fragment1)
a.receivedFragment(fragment2)
<-a.newFragmentCh
err = a.saveCurrentSegmentToStorage(context.Background(), false)
if err != nil {
t.Fatalf("got error writing entities to storage: %v", err)
}
seqNum := a.GetEntitySequenceNumber()
if seqNum != 1 {
t.Fatalf("expected sequence number 1, got %v", seqNum)
}
protoSegment0 = readSegmentFromStorage(t, core, path0)
err = proto.Unmarshal(protoSegment0.Value, &entityLog0)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
if len(entityLog0.Entities) != activitySegmentEntityCapacity {
t.Fatalf("unexpected entity length. Expected %d, got %d", activitySegmentEntityCapacity,
len(entityLog0.Entities))
}
protoSegment1 := readSegmentFromStorage(t, core, path1)
entityLog1 := activity.EntityActivityLog{}
err = proto.Unmarshal(protoSegment1.Value, &entityLog1)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
expectedCount := 14100 - activitySegmentEntityCapacity
if len(entityLog1.Entities) != expectedCount {
t.Fatalf("unexpected entity length. Expected %d, got %d", expectedCount,
len(entityLog1.Entities))
}
entityPresent := make(map[string]struct{})
for _, e := range entityLog0.Entities {
entityPresent[e.EntityID] = struct{}{}
}
for _, e := range entityLog1.Entities {
entityPresent[e.EntityID] = struct{}{}
}
for i := 0; i < 14100; i++ {
expectedID := genID(i)
if _, present := entityPresent[expectedID]; !present {
t.Fatalf("entity ID %v = %v not present", i, expectedID)
}
}
expectedNSCounts := map[string]uint64{
"root": 9,
"aaaaa": 11,
"bbbbb": 13,
}
tokenSegment := readSegmentFromStorage(t, core, tokenPath)
tokenCount := activity.TokenCount{}
err = proto.Unmarshal(tokenSegment.Value, &tokenCount)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
if !reflect.DeepEqual(expectedNSCounts, tokenCount.CountByNamespaceID) {
t.Fatalf("token counts are not equal, expected %v got %v", expectedNSCounts, tokenCount.CountByNamespaceID)
}
}
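// TestActivityLog_API_ConfigCRUD exercises the internal/counters/config
// endpoint: reading the defaults, rejecting invalid values, updating
// individual keys, and restoring the defaults.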
func TestActivityLog_API_ConfigCRUD(t *testing.T) {
core, b, _ := testCoreSystemBackend(t)
view := core.systemBarrierView
// Test reading the defaults
{
req := logical.TestRequest(t, logical.ReadOperation, "internal/counters/config")
req.Storage = view
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
defaults := map[string]interface{}{
"default_report_months": 12,
"retention_months": 24,
"enabled": activityLogEnabledDefaultValue,
"queries_available": false,
}
if diff := deep.Equal(resp.Data, defaults); len(diff) > 0 {
t.Fatalf("diff: %v", diff)
}
}
// Check Error Cases
{
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["default_report_months"] = 0
_, err := b.HandleRequest(namespace.RootContext(nil), req)
if err == nil {
t.Fatal("expected error")
}
req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["enabled"] = "bad-value"
_, err = b.HandleRequest(namespace.RootContext(nil), req)
if err == nil {
t.Fatal("expected error")
}
req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["retention_months"] = 0
req.Data["enabled"] = "enable"
_, err = b.HandleRequest(namespace.RootContext(nil), req)
if err == nil {
t.Fatal("expected error")
}
}
// Test single key updates
{
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["default_report_months"] = 1
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["retention_months"] = 2
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
req = logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["enabled"] = "enable"
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
req = logical.TestRequest(t, logical.ReadOperation, "internal/counters/config")
req.Storage = view
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
expected := map[string]interface{}{
"default_report_months": 1,
"retention_months": 2,
"enabled": "enable",
"queries_available": false,
}
if diff := deep.Equal(resp.Data, expected); len(diff) > 0 {
t.Fatalf("diff: %v", diff)
}
}
// Test updating all keys
{
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["enabled"] = "default"
req.Data["retention_months"] = 24
req.Data["default_report_months"] = 12
originalEnabled := core.activityLog.GetEnabled()
newEnabled := activityLogEnabledDefault
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
checkAPIWarnings(t, originalEnabled, newEnabled, resp)
req = logical.TestRequest(t, logical.ReadOperation, "internal/counters/config")
req.Storage = view
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
defaults := map[string]interface{}{
"default_report_months": 12,
"retention_months": 24,
"enabled": activityLogEnabledDefaultValue,
"queries_available": false,
}
if diff := deep.Equal(resp.Data, defaults); len(diff) > 0 {
t.Fatalf("diff: %v", diff)
}
}
}
func TestActivityLog_parseSegmentNumberFromPath(t *testing.T) {
testCases := []struct {
input string
expected int
expectExists bool
}{
{
input: "path/to/log/5",
expected: 5,
expectExists: true,
},
{
input: "/path/to/log/5",
expected: 5,
expectExists: true,
},
{
input: "path/to/log/",
expected: 0,
expectExists: false,
},
{
input: "path/to/log/foo",
expected: 0,
expectExists: false,
},
{
input: "",
expected: 0,
expectExists: false,
},
{
input: "5",
expected: 5,
expectExists: true,
},
}
for _, tc := range testCases {
result, ok := parseSegmentNumberFromPath(tc.input)
if result != tc.expected {
t.Errorf("expected: %d, got: %d for input %q", tc.expected, result, tc.input)
}
if ok != tc.expectExists {
t.Errorf("unexpected value presence. expected exists: %t, got: %t for input %q", tc.expectExists, ok, tc.input)
}
}
}
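// TestActivityLog_getLastEntitySegmentNumber verifies that the highest valid
// segment number is returned for each start time, and that start times with
// only malformed segment names, or with no segments at all, are reported as
// absent.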
func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
paths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"}
for _, path := range paths {
WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test"))
}
testCases := []struct {
input int64
expectedVal uint64
expectExists bool
}{
{
input: 992,
expectedVal: 0,
expectExists: true,
},
{
input: 1000,
expectedVal: 0,
expectExists: false,
},
{
input: 1001,
expectedVal: 0,
expectExists: false,
},
{
input: 1111,
expectedVal: 1,
expectExists: true,
},
{
input: 2222,
expectedVal: 0,
expectExists: false,
},
}
ctx := context.Background()
for _, tc := range testCases {
result, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0))
if err != nil {
t.Fatalf("unexpected error for input %d: %v", tc.input, err)
}
if exists != tc.expectExists {
t.Errorf("expected result exists: %t, got: %t for input: %d", tc.expectExists, exists, tc.input)
}
if result != tc.expectedVal {
t.Errorf("expected: %d got: %d for input: %d", tc.expectedVal, result, tc.input)
}
}
}
func TestActivityLog_tokenCountExists(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
paths := [...]string{"directtokens/992/0", "directtokens/1001/foo", "directtokens/1111/0", "directtokens/2222/1"}
for _, path := range paths {
WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test"))
}
testCases := []struct {
input int64
expectExists bool
}{
{
input: 992,
expectExists: true,
},
{
input: 1001,
expectExists: false,
},
{
input: 1111,
expectExists: true,
},
{
input: 2222,
expectExists: false,
},
}
ctx := context.Background()
for _, tc := range testCases {
exists, err := a.tokenCountExists(ctx, time.Unix(tc.input, 0))
if err != nil {
t.Fatalf("unexpected error for input %d: %v", tc.input, err)
}
if exists != tc.expectExists {
t.Errorf("expected segment to exist: %t but got: %t for input: %d", tc.expectExists, exists, tc.input)
}
}
}
// entityRecordsEqual compares the parts we care about from two activity entity record slices
// note: it sorts copies of the two slices, so differently ordered slices compare as equal
// without the caller's slices being reordered
func entityRecordsEqual(t *testing.T, record1, record2 []*activity.EntityRecord) bool {
t.Helper()
if record1 == nil {
return record2 == nil
}
if record2 == nil {
return record1 == nil
}
if len(record1) != len(record2) {
return false
}
// sort first on namespace, then on ID, then on timestamp
entityLessFn := func(e []*activity.EntityRecord, i, j int) bool {
ei := e[i]
ej := e[j]
nsComp := strings.Compare(ei.NamespaceID, ej.NamespaceID)
if nsComp == -1 {
return true
}
if nsComp == 1 {
return false
}
idComp := strings.Compare(ei.EntityID, ej.EntityID)
if idComp == -1 {
return true
}
if idComp == 1 {
return false
}
return ei.Timestamp < ej.Timestamp
}
entitiesCopy1 := make([]*activity.EntityRecord, len(record1))
entitiesCopy2 := make([]*activity.EntityRecord, len(record2))
copy(entitiesCopy1, record1)
copy(entitiesCopy2, record2)
sort.Slice(entitiesCopy1, func(i, j int) bool {
return entityLessFn(entitiesCopy1, i, j)
})
sort.Slice(entitiesCopy2, func(i, j int) bool {
return entityLessFn(entitiesCopy2, i, j)
})
for i, a := range entitiesCopy1 {
b := entitiesCopy2[i]
if a.EntityID != b.EntityID || a.NamespaceID != b.NamespaceID || a.Timestamp != b.Timestamp {
return false
}
}
return true
}
func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) {
t.Helper()
a.l.Lock()
defer a.l.Unlock()
a.fragmentLock.Lock()
defer a.fragmentLock.Unlock()
a.currentSegment = segmentInfo{
startTimestamp: time.Time{}.Unix(),
currentEntities: &activity.EntityActivityLog{
Entities: make([]*activity.EntityRecord, 0),
},
tokenCount: a.currentSegment.tokenCount,
entitySequenceNumber: 0,
}
a.activeEntities = make(map[string]struct{})
}
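// TestActivityLog_loadCurrentEntitySegment verifies that loading a stored
// entity segment populates the in-memory current segment and active entity
// set without overwriting the in-memory token counts.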
func TestActivityLog_loadCurrentEntitySegment(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
// we must verify that loadCurrentEntitySegment doesn't overwrite the in-memory token count
tokenRecords := make(map[string]uint64)
tokenRecords["test"] = 1
tokenCount := &activity.TokenCount{
CountByNamespaceID: tokenRecords,
}
a.SetTokenCount(tokenCount)
// setup in-storage data to load for testing
entityRecords := []*activity.EntityRecord{
{
EntityID: "11111111-1111-1111-1111-111111111111",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
{
EntityID: "22222222-2222-2222-2222-222222222222",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
}
testEntities1 := &activity.EntityActivityLog{
Entities: entityRecords[:1],
}
testEntities2 := &activity.EntityActivityLog{
Entities: entityRecords[1:2],
}
testEntities3 := &activity.EntityActivityLog{
Entities: entityRecords[:2],
}
time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix()
time2 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC).Unix()
testCases := []struct {
time int64
seqNum uint64
path string
entities *activity.EntityActivityLog
}{
{
time: time1,
seqNum: 0,
path: "entity/" + fmt.Sprint(time1) + "/0",
entities: testEntities1,
},
{
// we want to verify that data from segment 0 hasn't been loaded
time: time1,
seqNum: 1,
path: "entity/" + fmt.Sprint(time1) + "/1",
entities: testEntities2,
},
{
time: time2,
seqNum: 0,
path: "entity/" + fmt.Sprint(time2) + "/0",
entities: testEntities3,
},
}
for _, tc := range testCases {
data, err := proto.Marshal(tc.entities)
if err != nil {
t.Fatal(err)
}
WriteToStorage(t, core, ActivityLogPrefix+tc.path, data)
}
ctx := context.Background()
for _, tc := range testCases {
err := a.loadCurrentEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum)
if err != nil {
t.Fatalf("got error loading data for %q: %v", tc.path, err)
}
if !reflect.DeepEqual(a.GetCountByNamespaceID(), tokenCount.CountByNamespaceID) {
t.Errorf("this function should not wipe out the in-memory token count")
}
// verify accurate data in in-memory current segment
startTimestamp := a.GetStartTimestamp()
if startTimestamp != tc.time {
t.Errorf("bad timestamp loaded. expected: %v, got: %v for path %q", tc.time, startTimestamp, tc.path)
}
seqNum := a.GetEntitySequenceNumber()
if seqNum != tc.seqNum {
t.Errorf("bad sequence number loaded. expected: %v, got: %v for path %q", tc.seqNum, seqNum, tc.path)
}
currentEntities := a.GetCurrentEntities()
if !entityRecordsEqual(t, currentEntities.Entities, tc.entities.Entities) {
t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Entities, currentEntities, tc.path)
}
activeEntities := core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, tc.entities.Entities) {
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, activeEntities, tc.path)
}
a.resetEntitiesInMemory(t)
}
}
func TestActivityLog_loadPriorEntitySegment(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
a.SetEnable(true)
// setup in-storage data to load for testing
entityRecords := []*activity.EntityRecord{
{
EntityID: "11111111-1111-1111-1111-111111111111",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
{
EntityID: "22222222-2222-2222-2222-222222222222",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
}
testEntities1 := &activity.EntityActivityLog{
Entities: entityRecords[:1],
}
testEntities2 := &activity.EntityActivityLog{
Entities: entityRecords[:2],
}
time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix()
time2 := time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC).Unix()
testCases := []struct {
time int64
seqNum uint64
path string
entities *activity.EntityActivityLog
// set true if the in-memory active entities should be wiped because the next test case is a new month
// this also means that currentSegment.startTimestamp must be updated with :time:
refresh bool
}{
{
time: time1,
seqNum: 0,
path: "entity/" + fmt.Sprint(time1) + "/0",
entities: testEntities1,
refresh: true,
},
{
// verify that we don't have a duplicate (shouldn't be possible with the current implementation)
time: time1,
seqNum: 1,
path: "entity/" + fmt.Sprint(time1) + "/1",
entities: testEntities2,
refresh: true,
},
{
time: time2,
seqNum: 0,
path: "entity/" + fmt.Sprint(time2) + "/0",
entities: testEntities2,
refresh: true,
},
}
for _, tc := range testCases {
data, err := proto.Marshal(tc.entities)
if err != nil {
t.Fatal(err)
}
WriteToStorage(t, core, ActivityLogPrefix+tc.path, data)
}
ctx := context.Background()
for _, tc := range testCases {
if tc.refresh {
a.l.Lock()
a.fragmentLock.Lock()
a.activeEntities = make(map[string]struct{})
a.currentSegment.startTimestamp = tc.time
a.fragmentLock.Unlock()
a.l.Unlock()
}
err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum)
if err != nil {
t.Fatalf("got error loading data for %q: %v", tc.path, err)
}
activeEntities := core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, tc.entities.Entities) {
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, activeEntities, tc.path)
}
}
}
func TestActivityLog_loadTokenCount(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
// setup in-storage data to load for testing
tokenRecords := make(map[string]uint64)
for i := 1; i < 4; i++ {
nsID := "ns" + strconv.Itoa(i)
tokenRecords[nsID] = uint64(i)
}
tokenCount := &activity.TokenCount{
CountByNamespaceID: tokenRecords,
}
data, err := proto.Marshal(tokenCount)
if err != nil {
t.Fatal(err)
}
testCases := []struct {
time int64
path string
}{
{
time: 1111,
path: "directtokens/1111/0",
},
{
time: 2222,
path: "directtokens/2222/0",
},
}
for _, tc := range testCases {
WriteToStorage(t, core, ActivityLogPrefix+tc.path, data)
}
ctx := context.Background()
for _, tc := range testCases {
// a.currentSegment.tokenCount doesn't need to be wiped each iter since it happens in loadTokenSegment()
err := a.loadTokenCount(ctx, time.Unix(tc.time, 0))
if err != nil {
t.Fatalf("got error loading data for %q: %v", tc.path, err)
}
nsCount := a.GetCountByNamespaceID()
if !reflect.DeepEqual(nsCount, tokenRecords) {
t.Errorf("bad token count loaded. expected: %v got: %v for path %q", tokenRecords, nsCount, tc.path)
}
}
}
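// TestActivityLog_StopAndRestart simulates a seal/unseal cycle and verifies
// that the reloaded activity log can accept new entities and tokens and save
// a segment.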
func TestActivityLog_StopAndRestart(t *testing.T) {
core, b, _ := testCoreSystemBackend(t)
sysView := core.systemBarrierView
a := core.activityLog
ctx := namespace.RootContext(nil)
// Disable, then enable, to exercise newly-enabled code
a.SetConfig(ctx, activityConfig{
Enabled: "disable",
RetentionMonths: 12,
DefaultReportMonths: 12,
})
// On enterprise, a segment will be created, and
// disabling it will trigger deletion, so wait
// for that deletion to finish.
// (Alternatively, we could ensure that the next segment
// uses a different timestamp by waiting 1 second.)
a.WaitForDeletion()
// Go through request to ensure config is persisted
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = sysView
req.Data["enabled"] = "enable"
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %#v", resp)
}
// Simulate seal/unseal cycle
core.stopActivityLog()
var wg sync.WaitGroup
core.setupActivityLog(ctx, &wg)
wg.Wait()
a = core.activityLog
if a.GetCountByNamespaceID() == nil {
t.Fatalf("nil token count map")
}
a.AddEntityToFragment("1111-1111", "root", time.Now().Unix())
a.AddTokenToFragment("root")
err = a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatal(err)
}
}
// :base: is the timestamp to start from for the setup logic (used to simulate the newest log being in the past or the future)
// entity records returned include [0] data from a previous month and [1:] data from the current month
// token counts returned are from the current month
func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities, includeTokens bool) (*ActivityLog, []*activity.EntityRecord, map[string]uint64) {
t.Helper()
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
monthsAgo := base.AddDate(0, -3, 0)
var entityRecords []*activity.EntityRecord
if includeEntities {
entityRecords = []*activity.EntityRecord{
{
EntityID: "11111111-1111-1111-1111-111111111111",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
{
EntityID: "22222222-2222-2222-2222-222222222222",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
{
EntityID: "33333333-2222-2222-2222-222222222222",
NamespaceID: "root",
Timestamp: time.Now().Unix(),
},
}
testEntities1 := &activity.EntityActivityLog{
Entities: entityRecords[:1],
}
entityData1, err := proto.Marshal(testEntities1)
if err != nil {
t.Fatal(err)
}
testEntities2 := &activity.EntityActivityLog{
Entities: entityRecords[1:2],
}
entityData2, err := proto.Marshal(testEntities2)
if err != nil {
t.Fatal(err)
}
testEntities3 := &activity.EntityActivityLog{
Entities: entityRecords[2:],
}
entityData3, err := proto.Marshal(testEntities3)
if err != nil {
t.Fatal(err)
}
WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData1)
WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/0", entityData2)
WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/1", entityData3)
}
var tokenRecords map[string]uint64
if includeTokens {
tokenRecords = make(map[string]uint64)
for i := 1; i < 4; i++ {
nsID := "ns" + strconv.Itoa(i)
tokenRecords[nsID] = uint64(i)
}
tokenCount := &activity.TokenCount{
CountByNamespaceID: tokenRecords,
}
tokenData, err := proto.Marshal(tokenCount)
if err != nil {
t.Fatal(err)
}
WriteToStorage(t, core, ActivityLogPrefix+"directtokens/"+fmt.Sprint(base.Unix())+"/0", tokenData)
}
return a, entityRecords, tokenRecords
}
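// TestActivityLog_refreshFromStoredLog verifies that on refresh the newest
// entity segment for the current month becomes the in-memory current segment,
// all token counts are loaded, and the active entity set covers the whole
// month.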
func TestActivityLog_refreshFromStoredLog(t *testing.T) {
a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, time.Now().UTC())
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
expectedActive := &activity.EntityActivityLog{
Entities: expectedEntityRecords[1:],
}
expectedCurrent := &activity.EntityActivityLog{
Entities: expectedEntityRecords[2:],
}
currentEntities := a.GetCurrentEntities()
if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) {
// we only expect the newest entity segment to be loaded (for the current month)
t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities)
}
nsCount := a.GetCountByNamespaceID()
if !reflect.DeepEqual(nsCount, expectedTokenCounts) {
// we expect all token counts to be loaded
t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount)
}
activeEntities := a.core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) {
// we expect activeEntities to be loaded for the entire month
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities)
}
}
func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testing.T) {
a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true)
a.SetEnable(true)
var wg sync.WaitGroup
close(a.doneCh)
err := a.refreshFromStoredLog(context.Background(), &wg, time.Now().UTC())
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
expected := &activity.EntityActivityLog{
Entities: expectedEntityRecords[2:],
}
currentEntities := a.GetCurrentEntities()
if !entityRecordsEqual(t, currentEntities.Entities, expected.Entities) {
// we only expect the newest entity segment to be loaded (for the current month)
t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expected, currentEntities)
}
nsCount := a.GetCountByNamespaceID()
if !reflect.DeepEqual(nsCount, expectedTokenCounts) {
// we expect all token counts to be loaded
t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount)
}
activeEntities := a.core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, expected.Entities) {
// we only expect activeEntities to be loaded for the newest segment (for the current month)
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expected.Entities, activeEntities)
}
}
func TestActivityLog_refreshFromStoredLogContextCancelled(t *testing.T) {
a, _, _ := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true)
var wg sync.WaitGroup
ctx, cancelFn := context.WithCancel(context.Background())
cancelFn()
err := a.refreshFromStoredLog(ctx, &wg, time.Now().UTC())
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context cancelled error, got: %v", err)
}
}
func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) {
a, expectedEntityRecords, _ := setupActivityRecordsInStorage(t, time.Now().UTC(), true, false)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, time.Now().UTC())
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
expectedActive := &activity.EntityActivityLog{
Entities: expectedEntityRecords[1:],
}
expectedCurrent := &activity.EntityActivityLog{
Entities: expectedEntityRecords[2:],
}
currentEntities := a.GetCurrentEntities()
if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) {
// we expect all segments for the current month to be loaded
t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities)
}
activeEntities := a.core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) {
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities)
}
// we expect no tokens
nsCount := a.GetCountByNamespaceID()
if len(nsCount) > 0 {
t.Errorf("expected no token counts to be loaded. got: %v", nsCount)
}
}
func TestActivityLog_refreshFromStoredLogNoEntities(t *testing.T) {
a, _, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), false, true)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, time.Now().UTC())
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
nsCount := a.GetCountByNamespaceID()
if !reflect.DeepEqual(nsCount, expectedTokenCounts) {
// we expect all token counts to be loaded
t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount)
}
currentEntities := a.GetCurrentEntities()
if len(currentEntities.Entities) > 0 {
t.Errorf("expected no current entity segment to be loaded. got: %v", currentEntities)
}
activeEntities := a.core.GetActiveEntities()
if len(activeEntities) > 0 {
t.Errorf("expected no active entity segment to be loaded. got: %v", activeEntities)
}
}
func TestActivityLog_refreshFromStoredLogNoData(t *testing.T) {
now := time.Now().UTC()
a, _, _ := setupActivityRecordsInStorage(t, now, false, false)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, now)
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
a.ExpectCurrentSegmentRefreshed(t, now.Unix(), false)
}
func TestActivityLog_refreshFromStoredLogTwoMonthsPrevious(t *testing.T) {
// test what happens when the most recent data is from month M-2 (or earlier - same effect)
now := time.Now().UTC()
twoMonthsAgoStart := timeutil.StartOfPreviousMonth(timeutil.StartOfPreviousMonth(now))
a, _, _ := setupActivityRecordsInStorage(t, twoMonthsAgoStart, true, true)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, now)
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
a.ExpectCurrentSegmentRefreshed(t, now.Unix(), false)
}
func TestActivityLog_refreshFromStoredLogPreviousMonth(t *testing.T) {
// test what happens when most recent data is from month M-1
// we expect to load the data from the previous month so that the activeFragmentWorker
// can handle end of month rotations
monthStart := timeutil.StartOfMonth(time.Now().UTC())
oneMonthAgoStart := timeutil.StartOfPreviousMonth(monthStart)
a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, oneMonthAgoStart, true, true)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(context.Background(), &wg, time.Now().UTC())
if err != nil {
t.Fatalf("got error loading stored activity logs: %v", err)
}
wg.Wait()
expectedActive := &activity.EntityActivityLog{
Entities: expectedEntityRecords[1:],
}
expectedCurrent := &activity.EntityActivityLog{
Entities: expectedEntityRecords[2:],
}
currentEntities := a.GetCurrentEntities()
if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) {
// we only expect the newest entity segment to be loaded (for the current month)
t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities)
}
nsCount := a.GetCountByNamespaceID()
if !reflect.DeepEqual(nsCount, expectedTokenCounts) {
// we expect all token counts to be loaded
t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount)
}
activeEntities := a.core.GetActiveEntities()
if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) {
// we expect activeEntities to be loaded for the entire month
t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities)
}
}
func TestActivityLog_IncludeNamespace(t *testing.T) {
root := namespace.RootNamespace
a := &ActivityLog{}
nsA := &namespace.Namespace{
ID: "aaaaa",
Path: "a/",
}
nsC := &namespace.Namespace{
ID: "ccccc",
Path: "c/",
}
nsAB := &namespace.Namespace{
ID: "bbbbb",
Path: "a/b/",
}
testCases := []struct {
QueryNS *namespace.Namespace
RecordNS *namespace.Namespace
Expected bool
}{
{root, nil, true},
{root, root, true},
{root, nsA, true},
{root, nsAB, true},
{nsA, nsA, true},
{nsA, nsAB, true},
{nsAB, nsAB, true},
{nsA, root, false},
{nsA, nil, false},
{nsAB, root, false},
{nsAB, nil, false},
{nsAB, nsA, false},
{nsC, nsA, false},
{nsC, nsAB, false},
}
for _, tc := range testCases {
if a.includeInResponse(tc.QueryNS, tc.RecordNS) != tc.Expected {
t.Errorf("bad response for query %v record %v, expected %v",
tc.QueryNS, tc.RecordNS, tc.Expected)
}
}
}
func TestActivityLog_DeleteWorker(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
paths := []string{
"entity/1111/1",
"entity/1111/2",
"entity/1111/3",
"entity/1112/1",
"directtokens/1111/1",
"directtokens/1112/1",
}
for _, path := range paths {
WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test"))
}
doneCh := make(chan struct{})
timeout := time.After(20 * time.Second)
go a.deleteLogWorker(1111, doneCh)
select {
case <-doneCh:
break
case <-timeout:
t.Fatalf("timed out")
}
// Check segments still present
readSegmentFromStorage(t, core, ActivityLogPrefix+"entity/1112/1")
readSegmentFromStorage(t, core, ActivityLogPrefix+"directtokens/1112/1")
// Check other segments not present
expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/1")
expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/2")
expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/3")
expectMissingSegment(t, core, ActivityLogPrefix+"directtokens/1111/1")
}
// checkAPIWarnings ensures there is a warning if switching from enabled -> disabled,
// and no response otherwise
func checkAPIWarnings(t *testing.T, originalEnabled, newEnabled bool, resp *logical.Response) {
t.Helper()
expectWarning := originalEnabled && !newEnabled
switch {
case !expectWarning && resp != nil:
t.Fatalf("got unexpected response: %#v", resp)
case expectWarning && resp == nil:
t.Fatal("expected response (containing warning) when switching from enabled to disabled")
case expectWarning && len(resp.Warnings) == 0:
t.Fatal("expected warning when switching from enabled to disabled")
}
}
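// TestActivityLog_EnableDisable verifies that disabling the activity log
// deletes the current month's segments and resets the in-memory state, and
// that re-enabling writes fresh empty entity and directtokens segments.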
func TestActivityLog_EnableDisable(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)
core, b, _ := testCoreSystemBackend(t)
a := core.activityLog
view := core.systemBarrierView
ctx := namespace.RootContext(nil)
enableRequest := func() {
t.Helper()
originalEnabled := a.GetEnabled()
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["enabled"] = "enable"
resp, err := b.HandleRequest(ctx, req)
if err != nil {
t.Fatalf("err: %v", err)
}
// don't really need originalEnabled, but might as well be correct
checkAPIWarnings(t, originalEnabled, true, resp)
}
disableRequest := func() {
t.Helper()
originalEnabled := a.GetEnabled()
req := logical.TestRequest(t, logical.UpdateOperation, "internal/counters/config")
req.Storage = view
req.Data["enabled"] = "disable"
resp, err := b.HandleRequest(ctx, req)
if err != nil {
t.Fatalf("err: %v", err)
}
checkAPIWarnings(t, originalEnabled, false, resp)
}
// enable (if not already) and write a segment
enableRequest()
id1 := "11111111-1111-1111-1111-111111111111"
id2 := "22222222-2222-2222-2222-222222222222"
id3 := "33333333-3333-3333-3333-333333333333"
a.AddEntityToFragment(id1, "root", time.Now().Unix())
a.AddEntityToFragment(id2, "root", time.Now().Unix())
a.SetStartTimestamp(a.GetStartTimestamp() - 10)
seg1 := a.GetStartTimestamp()
err := a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatal(err)
}
// verify segment exists
path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg1)
readSegmentFromStorage(t, core, path)
// Add in-memory fragment
a.AddEntityToFragment(id3, "root", time.Now().Unix())
// disable and verify segment no longer exists
disableRequest()
timeout := time.After(20 * time.Second)
select {
case <-a.deleteDone:
break
case <-timeout:
t.Fatalf("timed out")
}
expectMissingSegment(t, core, path)
a.ExpectCurrentSegmentRefreshed(t, 0, false)
// enable (if not already) which force-writes an empty segment
enableRequest()
seg2 := a.GetStartTimestamp()
if seg1 >= seg2 {
t.Errorf("bad second segment timestamp, %v >= %v", seg1, seg2)
}
// Verify empty segments are present
path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg2)
readSegmentFromStorage(t, core, path)
path = fmt.Sprintf("%vdirecttokens/%v/0", ActivityLogPrefix, seg2)
readSegmentFromStorage(t, core, path)
}
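// TestActivityLog_EndOfMonth verifies that HandleEndOfMonth writes out the
// current segment, records an intent log at sys/counters/activity/endofmonth,
// and starts a new segment for the following month.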
func TestActivityLog_EndOfMonth(t *testing.T) {
// We only want *fake* end of months, *real* ones are too scary.
timeutil.SkipAtEndOfMonth(t)
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
ctx := namespace.RootContext(nil)
// Make sure we're enabled.
a.SetConfig(ctx, activityConfig{
Enabled: "enable",
RetentionMonths: 12,
DefaultReportMonths: 12,
})
id1 := "11111111-1111-1111-1111-111111111111"
id2 := "22222222-2222-2222-2222-222222222222"
id3 := "33333333-3333-3333-3333-333333333333"
a.AddEntityToFragment(id1, "root", time.Now().Unix())
month0 := time.Now().UTC()
segment0 := a.GetStartTimestamp()
month1 := timeutil.StartOfNextMonth(month0)
month2 := timeutil.StartOfNextMonth(month1)
// Trigger end-of-month
a.HandleEndOfMonth(month1)
// Check segment is present, with 1 entity
path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, segment0)
protoSegment := readSegmentFromStorage(t, core, path)
out := &activity.EntityActivityLog{}
err := proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatal(err)
}
segment1 := a.GetStartTimestamp()
expectedTimestamp := timeutil.StartOfMonth(month1).Unix()
if segment1 != expectedTimestamp {
t.Errorf("expected segment timestamp %v got %v", expectedTimestamp, segment1)
}
// Check intent log is present
intentRaw, err := core.barrier.Get(ctx, "sys/counters/activity/endofmonth")
if err != nil {
t.Fatal(err)
}
var intent ActivityIntentLog
err = intentRaw.DecodeJSON(&intent)
if err != nil {
t.Fatal(err)
}
if intent.PreviousMonth != segment0 {
t.Errorf("expected previous month %v got %v", segment0, intent.PreviousMonth)
}
if intent.NextMonth != segment1 {
t.Errorf("expected previous month %v got %v", segment1, intent.NextMonth)
}
a.AddEntityToFragment(id2, "root", time.Now().Unix())
a.HandleEndOfMonth(month2)
segment2 := a.GetStartTimestamp()
a.AddEntityToFragment(id3, "root", time.Now().Unix())
err = a.saveCurrentSegmentToStorage(ctx, false)
if err != nil {
t.Fatal(err)
}
// Check all three segments still present, with correct entities
testCases := []struct {
SegmentTimestamp int64
ExpectedEntityIDs []string
}{
{segment0, []string{id1}},
{segment1, []string{id2}},
{segment2, []string{id3}},
}
for i, tc := range testCases {
t.Logf("checking segment %v timestamp %v", i, tc.SegmentTimestamp)
path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, tc.SegmentTimestamp)
protoSegment := readSegmentFromStorage(t, core, path)
out := &activity.EntityActivityLog{}
err = proto.Unmarshal(protoSegment.Value, out)
if err != nil {
t.Fatalf("could not unmarshal protobuf: %v", err)
}
expectedEntityIDs(t, out, tc.ExpectedEntityIDs)
}
}
func TestActivityLog_SaveAfterDisable(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
a := core.activityLog
a.SetConfig(ctx, activityConfig{
Enabled: "enable",
RetentionMonths: 12,
DefaultReportMonths: 12,
})
a.AddEntityToFragment("1111-1111-11111111", "root", time.Now().Unix())
startTimestamp := a.GetStartTimestamp()
// This kicks off an asynchronous delete
a.SetConfig(ctx, activityConfig{
Enabled: "disable",
RetentionMonths: 12,
DefaultReportMonths: 12,
})
timer := time.After(10 * time.Second)
select {
case <-timer:
t.Fatal("timeout waiting for delete to finish")
case <-a.deleteDone:
break
}
// Segment should not be written even with force
err := a.saveCurrentSegmentToStorage(context.Background(), true)
if err != nil {
t.Fatal(err)
}
path := ActivityLogPrefix + "entity/0/0"
expectMissingSegment(t, core, path)
path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, startTimestamp)
expectMissingSegment(t, core, path)
}
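// TestActivityLog_Precompute writes entity segments for several months, runs
// the precomputed query worker after each simulated month rollover, and checks
// the resulting per-namespace counts and reported gauge metrics.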
func TestActivityLog_Precompute(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)
january := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
august := time.Date(2020, 8, 15, 12, 0, 0, 0, time.UTC)
september := timeutil.StartOfMonth(time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC))
october := timeutil.StartOfMonth(time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC))
november := timeutil.StartOfMonth(time.Date(2020, 11, 1, 0, 0, 0, 0, time.UTC))
core, _, _, sink := TestCoreUnsealedWithMetrics(t)
a := core.activityLog
ctx := namespace.RootContext(nil)
// Generate overlapping sets of entity IDs from this list.
// january: 40-44 RRRRR
// first month: 0-19 RRRRRAAAAABBBBBRRRRR
// second month: 10-29 BBBBBRRRRRRRRRRCCCCC
// third month: 15-39 RRRRRRRRRRCCCCCRRRRRBBBBB
entityRecords := make([]*activity.EntityRecord, 45)
entityNamespaces := []string{"root", "aaaaa", "bbbbb", "root", "root", "ccccc", "root", "bbbbb", "rrrrr"}
for i := range entityRecords {
entityRecords[i] = &activity.EntityRecord{
EntityID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i),
NamespaceID: entityNamespaces[i/5],
Timestamp: time.Now().Unix(),
}
}
toInsert := []struct {
StartTime int64
Segment uint64
Entities []*activity.EntityRecord
}{
// January, should not be included
{
january.Unix(),
0,
entityRecords[40:45],
},
// Artificially split August and October
{ // 1
august.Unix(),
0,
entityRecords[:13],
},
{ // 2
august.Unix(),
1,
entityRecords[13:20],
},
{ // 3
september.Unix(),
0,
entityRecords[10:30],
},
{ // 4
october.Unix(),
0,
entityRecords[15:40],
},
{
october.Unix(),
1,
entityRecords[15:40],
},
{
october.Unix(),
2,
entityRecords[17:23],
},
}
// Note that precomputedQuery worker doesn't filter
// for times <= the one it was asked to do. Is that a problem?
// Here, it means that we can't insert everything *first* and do multiple
// test cases, we have to write logs incrementally.
doInsert := func(i int) {
segment := toInsert[i]
eal := &activity.EntityActivityLog{
Entities: segment.Entities,
}
data, err := proto.Marshal(eal)
if err != nil {
t.Fatal(err)
}
path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment)
WriteToStorage(t, core, path, data)
}
expectedCounts := []struct {
StartTime time.Time
EndTime time.Time
ByNamespace map[string]int
}{
// First test case
{
august,
timeutil.EndOfMonth(august),
map[string]int{
"root": 10,
"aaaaa": 5,
"bbbbb": 5,
},
},
// Second test case
{
august,
timeutil.EndOfMonth(september),
map[string]int{
"root": 15,
"aaaaa": 5,
"bbbbb": 5,
"ccccc": 5,
},
},
{
september,
timeutil.EndOfMonth(september),
map[string]int{
"root": 10,
"bbbbb": 5,
"ccccc": 5,
},
},
// Third test case
{
august,
timeutil.EndOfMonth(october),
map[string]int{
"root": 20,
"aaaaa": 5,
"bbbbb": 10,
"ccccc": 5,
},
},
{
september,
timeutil.EndOfMonth(october),
map[string]int{
"root": 15,
"bbbbb": 10,
"ccccc": 5,
},
},
{
october,
timeutil.EndOfMonth(october),
map[string]int{
"root": 15,
"bbbbb": 5,
"ccccc": 5,
},
},
}
checkPrecomputedQuery := func(i int) {
t.Helper()
pq, err := a.queryStore.Get(ctx, expectedCounts[i].StartTime, expectedCounts[i].EndTime)
if err != nil {
t.Fatal(err)
}
if pq == nil {
t.Errorf("empty result for %v -- %v", expectedCounts[i].StartTime, expectedCounts[i].EndTime)
}
if len(pq.Namespaces) != len(expectedCounts[i].ByNamespace) {
t.Errorf("mismatched number of namespaces, expected %v got %v",
len(expectedCounts[i].ByNamespace), len(pq.Namespaces))
}
for _, nsRecord := range pq.Namespaces {
val, ok := expectedCounts[i].ByNamespace[nsRecord.NamespaceID]
if !ok {
t.Errorf("unexpected namespace %v", nsRecord.NamespaceID)
continue
}
if uint64(val) != nsRecord.Entities {
t.Errorf("wrong number of entities in %v: expected %v, got %v",
nsRecord.NamespaceID, val, nsRecord.Entities)
}
}
if !pq.StartTime.Equal(expectedCounts[i].StartTime) {
t.Errorf("mismatched start time: expected %v got %v",
expectedCounts[i].StartTime, pq.StartTime)
}
if !pq.EndTime.Equal(expectedCounts[i].EndTime) {
t.Errorf("mismatched end time: expected %v got %v",
expectedCounts[i].EndTime, pq.EndTime)
}
}
testCases := []struct {
InsertUpTo int // index in the toInsert array
PrevMonth int64
NextMonth int64
ExpectedUpTo int // index in the expectedCounts array
}{
{
2, // jan-august
august.Unix(),
september.Unix(),
0, // august-august
},
{
3, // jan-sept
september.Unix(),
october.Unix(),
2, // august-september
},
{
6, // jan-oct
october.Unix(),
november.Unix(),
5, // august-october
},
}
inserted := -1
for _, tc := range testCases {
t.Logf("tc %+v", tc)
// Persists across loops
for inserted < tc.InsertUpTo {
inserted += 1
t.Logf("inserting segment %v", inserted)
doInsert(inserted)
}
intent := &ActivityIntentLog{
PreviousMonth: tc.PrevMonth,
NextMonth: tc.NextMonth,
}
data, err := json.Marshal(intent)
if err != nil {
t.Fatal(err)
}
WriteToStorage(t, core, "sys/counters/activity/endofmonth", data)
// Pretend we've successfully rolled over to the following month
a.SetStartTimestamp(tc.NextMonth)
err = a.precomputedQueryWorker()
if err != nil {
t.Fatal(err)
}
expectMissingSegment(t, core, "sys/counters/activity/endofmonth")
for i := 0; i <= tc.ExpectedUpTo; i++ {
checkPrecomputedQuery(i)
}
}
// Check metrics on the last precomputed query
// (otherwise we need a way to reset the in-memory metrics between test cases.)
intervals := sink.Data()
// Test crossed an interval boundary, don't try to deal with it.
if len(intervals) > 1 {
t.Skip("Detected interval crossing.")
}
expectedGauges := []struct {
Name string
NamespaceLabel string
Value float32
}{
// october values
{
"identity.entity.active.monthly",
"root",
15.0,
},
{
"identity.entity.active.monthly",
"deleted-bbbbb", // No namespace entry for this fake ID
5.0,
},
{
"identity.entity.active.monthly",
"deleted-ccccc",
5.0,
},
// august-september values
{
"identity.entity.active.reporting_period",
"root",
20.0,
},
{
"identity.entity.active.reporting_period",
"deleted-aaaaa",
5.0,
},
{
"identity.entity.active.reporting_period",
"deleted-bbbbb",
10.0,
},
{
"identity.entity.active.reporting_period",
"deleted-ccccc",
5.0,
},
}
for _, g := range expectedGauges {
found := false
for _, actual := range intervals[0].Gauges {
actualNamespaceLabel := ""
for _, l := range actual.Labels {
if l.Name == "namespace" {
actualNamespaceLabel = l.Value
break
}
}
if actual.Name == g.Name && actualNamespaceLabel == g.NamespaceLabel {
found = true
if actual.Value != g.Value {
t.Errorf("Mismatched value for %v %v %v != %v",
g.Name, g.NamespaceLabel, actual.Value, g.Value)
}
break
}
}
if !found {
t.Errorf("No guage found for %v %v",
g.Name, g.NamespaceLabel)
}
}
}
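// BlockingInmemStorage blocks every storage operation until the request
// context is cancelled, allowing tests to verify that workers shut down
// cleanly.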
type BlockingInmemStorage struct{}
func (b *BlockingInmemStorage) List(ctx context.Context, prefix string) ([]string, error) {
<-ctx.Done()
return nil, errors.New("fake implementation")
}
func (b *BlockingInmemStorage) Get(ctx context.Context, key string) (*logical.StorageEntry, error) {
<-ctx.Done()
return nil, errors.New("fake implementation")
}
func (b *BlockingInmemStorage) Put(ctx context.Context, entry *logical.StorageEntry) error {
<-ctx.Done()
return errors.New("fake implementation")
}
func (b *BlockingInmemStorage) Delete(ctx context.Context, key string) error {
<-ctx.Done()
return errors.New("fake implementation")
}
func TestActivityLog_PrecomputeCancel(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
// Substitute in a new view
a.view = NewBarrierView(&BlockingInmemStorage{}, "test")
core.stopActivityLog()
done := make(chan struct{})
// This will block if the shutdown didn't work.
go func() {
a.precomputedQueryWorker()
close(done)
}()
timeout := time.After(5 * time.Second)
select {
case <-done:
break
case <-timeout:
t.Fatalf("timeout waiting for worker to finish")
}
}
func TestActivityLog_NextMonthStart(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)
now := time.Now().UTC()
year, month, _ := now.Date()
computedStart := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC).AddDate(0, 1, 0)
testCases := []struct {
SegmentStart int64
ExpectedTime time.Time
}{
{
0,
computedStart,
},
{
time.Date(2021, 2, 12, 13, 14, 15, 0, time.UTC).Unix(),
time.Date(2021, 3, 1, 0, 0, 0, 0, time.UTC),
},
{
time.Date(2021, 3, 1, 0, 0, 0, 0, time.UTC).Unix(),
time.Date(2021, 4, 1, 0, 0, 0, 0, time.UTC),
},
}
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
for _, tc := range testCases {
t.Logf("segmentStart=%v", tc.SegmentStart)
a.SetStartTimestamp(tc.SegmentStart)
actual := a.StartOfNextMonth()
if !actual.Equal(tc.ExpectedTime) {
t.Errorf("expected %v, got %v", tc.ExpectedTime, actual)
}
}
}
// The retention worker is called on unseal; wait for it to finish before
// proceeding with the test.
func waitForRetentionWorkerToFinish(t *testing.T, a *ActivityLog) {
t.Helper()
timeout := time.After(30 * time.Second)
select {
case <-a.retentionDone:
return
case <-timeout:
t.Fatal("timeout waiting for retention worker to finish")
}
}
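// TestActivityLog_Deletion verifies that the retention worker removes entity
// segments, directtokens segments, and precomputed queries older than the
// configured number of months while keeping newer data.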
func TestActivityLog_Deletion(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
waitForRetentionWorkerToFinish(t, a)
times := []time.Time{
time.Date(2019, 1, 15, 1, 2, 3, 0, time.UTC), // 0
time.Date(2019, 3, 15, 1, 2, 3, 0, time.UTC),
time.Date(2019, 4, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 5, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 6, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 7, 1, 0, 0, 0, 0, time.UTC), // 5
time.Date(2019, 8, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 10, 1, 0, 0, 0, 0, time.UTC),
time.Date(2019, 11, 1, 0, 0, 0, 0, time.UTC), // <-- 12 months starts here
time.Date(2019, 12, 1, 0, 0, 0, 0, time.UTC), // 10
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 2, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 3, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 5, 1, 0, 0, 0, 0, time.UTC), // 15
time.Date(2020, 6, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 7, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 8, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC), // 20
time.Date(2020, 11, 1, 0, 0, 0, 0, time.UTC),
}
novIndex := len(times) - 1
paths := make([][]string, len(times))
for i, start := range times {
// no entities in some months, just for fun
for j := 0; j < (i+3)%5; j++ {
entityPath := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, start.Unix(), j)
paths[i] = append(paths[i], entityPath)
WriteToStorage(t, core, entityPath, []byte("test"))
}
tokenPath := fmt.Sprintf("%vdirecttokens/%v/0", ActivityLogPrefix, start.Unix())
paths[i] = append(paths[i], tokenPath)
WriteToStorage(t, core, tokenPath, []byte("test"))
// No queries for November yet
if i < novIndex {
for _, endTime := range times[i+1 : novIndex] {
queryPath := fmt.Sprintf("sys/counters/activity/queries/%v/%v", start.Unix(), endTime.Unix())
paths[i] = append(paths[i], queryPath)
WriteToStorage(t, core, queryPath, []byte("test"))
}
}
}
checkPresent := func(i int) {
t.Helper()
for _, p := range paths[i] {
readSegmentFromStorage(t, core, p)
}
}
checkAbsent := func(i int) {
t.Helper()
for _, p := range paths[i] {
expectMissingSegment(t, core, p)
}
}
t.Log("24 months")
now := times[len(times)-1]
err := a.retentionWorker(now, 24)
if err != nil {
t.Fatal(err)
}
for i := range times {
checkPresent(i)
}
t.Log("12 months")
err = a.retentionWorker(now, 12)
if err != nil {
t.Fatal(err)
}
for i := 0; i <= 8; i++ {
checkAbsent(i)
}
for i := 9; i <= 21; i++ {
checkPresent(i)
}
t.Log("1 month")
err = a.retentionWorker(now, 1)
if err != nil {
t.Fatal(err)
}
for i := 0; i <= 19; i++ {
checkAbsent(i)
}
checkPresent(20)
checkPresent(21)
t.Log("0 months")
err = a.retentionWorker(now, 0)
if err != nil {
t.Fatal(err)
}
for i := 0; i <= 20; i++ {
checkAbsent(i)
}
checkPresent(21)
}
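// TestActivityLog_partialMonthClientCount verifies that the partial-month
// report combines distinct entities and non-entity tokens for the current
// month into the returned client count.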
func TestActivityLog_partialMonthClientCount(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)
ctx := context.Background()
now := time.Now().UTC()
a, entities, tokenCounts := setupActivityRecordsInStorage(t, timeutil.StartOfMonth(now), true, true)
a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(ctx, &wg, now)
if err != nil {
t.Fatalf("error loading clients: %v", err)
}
wg.Wait()
// entities[0] is from a previous month
partialMonthEntityCount := len(entities[1:])
var partialMonthTokenCount int
for _, countByNS := range tokenCounts {
partialMonthTokenCount += int(countByNS)
}
expectedClientCount := partialMonthEntityCount + partialMonthTokenCount
results := a.partialMonthClientCount(ctx)
if results == nil {
t.Fatal("no results to test")
}
entityCount, ok := results["distinct_entities"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}
if entityCount != partialMonthEntityCount {
t.Errorf("bad entity count. expected %d, got %d", partialMonthEntityCount, entityCount)
}
tokenCount, ok := results["non_entity_tokens"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}
if tokenCount != partialMonthTokenCount {
t.Errorf("bad token count. expected %d, got %d", partialMonthTokenCount, tokenCount)
}
clientCount, ok := results["clients"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}
if clientCount != expectedClientCount {
t.Errorf("bad client count. expected %d, got %d", expectedClientCount, clientCount)
}
}