diff --git a/vault/activity_log.go b/vault/activity_log.go index 8d11a7ab5..221ad5c8b 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -409,67 +409,82 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for // :force: forces a save of tokens/entities even if the in-memory log is empty func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool) error { - entityPath := fmt.Sprintf("%s%d/%d", activityEntityBasePath, a.currentSegment.startTimestamp, a.currentSegment.clientSequenceNumber) + _, err := a.saveSegmentEntitiesInternal(ctx, a.currentSegment, force) + if err != nil { + return err + } + _, err = a.saveSegmentTokensInternal(ctx, a.currentSegment, force) + return err +} + +func (a *ActivityLog) saveSegmentTokensInternal(ctx context.Context, currentSegment segmentInfo, force bool) (string, error) { + if len(currentSegment.tokenCount.CountByNamespaceID) == 0 && !force { + return "", nil + } // RFC (VLT-120) defines this as 1-indexed, but it should be 0-indexed - tokenPath := fmt.Sprintf("%s%d/0", activityTokenBasePath, a.currentSegment.startTimestamp) + tokenPath := fmt.Sprintf("%s%d/0", activityTokenBasePath, currentSegment.startTimestamp) + // We must still allow for the tokenCount of the current segment to + // be written to storage, since if we remove this code we will incur + // data loss for one segment's worth of TWEs. + // We can get away with simply using the oldest version stored because + // the storing of versions was introduced at the same time as this code. + oldestVersion, oldestUpgradeTime, err := a.core.FindOldestVersionTimestamp() + switch { + case err != nil: + a.logger.Error(fmt.Sprintf("unable to retrieve oldest version timestamp: %s", err.Error())) + case len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 && + (oldestUpgradeTime.Add(time.Duration(trackedTWESegmentPeriod * time.Hour)).Before(time.Now())): + a.logger.Error(fmt.Sprintf("storing nonzero token count over a month after vault was upgraded to %s", oldestVersion)) + default: + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + a.logger.Info("storing nonzero token count") + } + } + tokenCount, err := proto.Marshal(a.currentSegment.tokenCount) + if err != nil { + return "", err + } + + a.logger.Trace("writing segment", "path", tokenPath) + err = a.view.Put(ctx, &logical.StorageEntry{ + Key: tokenPath, + Value: tokenCount, + }) + if err != nil { + return "", err + } + + return tokenPath, nil +} + +func (a *ActivityLog) saveSegmentEntitiesInternal(ctx context.Context, currentSegment segmentInfo, force bool) (string, error) { + entityPath := fmt.Sprintf("%s%d/%d", activityEntityBasePath, currentSegment.startTimestamp, currentSegment.clientSequenceNumber) for _, client := range a.currentSegment.currentClients.Clients { // Explicitly catch and throw clear error message if client ID creation and storage // results in a []byte that doesn't assert into a valid string. if !utf8.ValidString(client.ClientID) { - return fmt.Errorf("client ID %q is not a valid string:", client.ClientID) + return "", fmt.Errorf("client ID %q is not a valid string:", client.ClientID) } } - if len(a.currentSegment.currentClients.Clients) > 0 || force { - clients, err := proto.Marshal(a.currentSegment.currentClients) - if err != nil { - return err - } - - a.logger.Trace("writing segment", "path", entityPath) - err = a.view.Put(ctx, &logical.StorageEntry{ - Key: entityPath, - Value: clients, - }) - if err != nil { - return err - } + if len(currentSegment.currentClients.Clients) == 0 && !force { + return "", nil + } + clients, err := proto.Marshal(currentSegment.currentClients) + if err != nil { + return entityPath, err } - // We must still allow for the tokenCount of the current segment to - // be written to storage, since if we remove this code we will incur - // data loss for one segment's worth of TWEs. - if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 || force { - // We can get away with simply using the oldest version stored because - // the storing of versions was introduced at the same time as this code. - oldestVersion, oldestUpgradeTime, err := a.core.FindOldestVersionTimestamp() - switch { - case err != nil: - a.logger.Error(fmt.Sprintf("unable to retrieve oldest version timestamp: %s", err.Error())) - case len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 && - (oldestUpgradeTime.Add(time.Duration(trackedTWESegmentPeriod * time.Hour)).Before(a.clock.Now())): - a.logger.Error(fmt.Sprintf("storing nonzero token count over a month after vault was upgraded to %s", oldestVersion)) - default: - if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { - a.logger.Info("storing nonzero token count") - } - } - tokenCount, err := proto.Marshal(a.currentSegment.tokenCount) - if err != nil { - return err - } - - a.logger.Trace("writing segment", "path", tokenPath) - err = a.view.Put(ctx, &logical.StorageEntry{ - Key: tokenPath, - Value: tokenCount, - }) - if err != nil { - return err - } + a.logger.Trace("writing segment", "path", entityPath) + err = a.view.Put(ctx, &logical.StorageEntry{ + Key: entityPath, + Value: clients, + }) + if err != nil { + return "", err } - return nil + return entityPath, err } // parseSegmentNumberFromPath returns the segment number from a path diff --git a/vault/logical_system_activity_write_testonly.go b/vault/logical_system_activity_write_testonly.go index c1a67b8dd..45c055e1e 100644 --- a/vault/logical_system_activity_write_testonly.go +++ b/vault/logical_system_activity_write_testonly.go @@ -8,9 +8,12 @@ package vault import ( "context" "fmt" + "sync" + "time" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/activity" @@ -53,7 +56,34 @@ func (b *SystemBackend) handleActivityWriteData(ctx context.Context, request *lo if len(input.Data) == 0 { return logical.ErrorResponse("Missing required \"data\" values"), logical.ErrInvalidRequest } - return nil, nil + + numMonths := 0 + for _, month := range input.Data { + if int(month.GetMonthsAgo()) > numMonths { + numMonths = int(month.GetMonthsAgo()) + } + } + generated := newMultipleMonthsActivityClients(numMonths + 1) + for _, month := range input.Data { + err := generated.processMonth(ctx, b.Core, month) + if err != nil { + return logical.ErrorResponse("failed to process data for month %d", month.GetMonthsAgo()), err + } + } + + opts := make(map[generation.WriteOptions]struct{}, len(input.Write)) + for _, opt := range input.Write { + opts[opt] = struct{}{} + } + paths, err := generated.write(ctx, opts, b.Core.activityLog) + if err != nil { + return logical.ErrorResponse("failed to write data"), err + } + return &logical.Response{ + Data: map[string]interface{}{ + "paths": paths, + }, + }, nil } // singleMonthActivityClients holds a single month's client IDs, in the order they were seen @@ -287,6 +317,47 @@ func (m *multipleMonthsActivityClients) addRepeatedClients(monthsAgo int32, c *g return nil } +func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[generation.WriteOptions]struct{}, activityLog *ActivityLog) ([]string, error) { + now := timeutil.StartOfMonth(time.Now().UTC()) + paths := []string{} + for i, month := range m.months { + var timestamp time.Time + if i > 0 { + timestamp = timeutil.StartOfMonth(timeutil.MonthsPreviousTo(i, now)) + } else { + timestamp = now + } + segments, err := month.populateSegments() + if err != nil { + return nil, err + } + for segmentIndex, segment := range segments { + if _, ok := opts[generation.WriteOptions_WRITE_ENTITIES]; ok { + if segment == nil { + // skip the index + continue + } + entityPath, err := activityLog.saveSegmentEntitiesInternal(ctx, segmentInfo{ + startTimestamp: timestamp.Unix(), + currentClients: &activity.EntityActivityLog{Clients: segment}, + clientSequenceNumber: uint64(segmentIndex), + tokenCount: &activity.TokenCount{}, + }, true) + if err != nil { + return nil, err + } + paths = append(paths, entityPath) + } + } + } + wg := sync.WaitGroup{} + err := activityLog.refreshFromStoredLog(ctx, &wg, now) + if err != nil { + return nil, err + } + return paths, nil +} + func newMultipleMonthsActivityClients(numberOfMonths int) *multipleMonthsActivityClients { m := &multipleMonthsActivityClients{ months: make([]*singleMonthActivityClients, numberOfMonths), diff --git a/vault/logical_system_activity_write_testonly_test.go b/vault/logical_system_activity_write_testonly_test.go index f72a0e167..b9b1a939a 100644 --- a/vault/logical_system_activity_write_testonly_test.go +++ b/vault/logical_system_activity_write_testonly_test.go @@ -7,6 +7,7 @@ package vault import ( "context" + "sort" "testing" "github.com/hashicorp/vault/helper/namespace" @@ -14,6 +15,8 @@ import ( "github.com/hashicorp/vault/vault/activity" "github.com/hashicorp/vault/vault/activity/generation" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" ) // TestSystemBackend_handleActivityWriteData calls the activity log write endpoint and confirms that the inputs are @@ -24,6 +27,7 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { operation logical.Operation input map[string]interface{} wantError error + wantPaths int }{ { name: "read fails", @@ -70,6 +74,12 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { operation: logical.CreateOperation, input: map[string]interface{}{"input": `{"write":["WRITE_PRECOMPUTED_QUERIES"],"data":[{"current_month":true,"all":{"clients":[{"count":5}]}}]}`}, }, + { + name: "entities with multiple segments", + operation: logical.CreateOperation, + input: map[string]interface{}{"input": `{"write":["WRITE_ENTITIES"],"data":[{"current_month":true,"num_segments":3,"all":{"clients":[{"count":5}]}}]}`}, + wantPaths: 3, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -81,6 +91,8 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { require.Equal(t, tc.wantError, err, resp.Error()) } else { require.NoError(t, err) + paths := resp.Data["paths"].([]string) + require.Len(t, paths, tc.wantPaths) } }) } @@ -428,3 +440,137 @@ func Test_singleMonthActivityClients_populateSegments(t *testing.T) { }) } } + +// Test_multipleMonthsActivityClients_write_entities writes 4 months of data +// splitting some months across segments and using empty segments and skipped +// segments. Entities are written and then storage is queried. The test verifies +// that the correct timestamps are present in the activity log and that the correct +// segment numbers for each month contain the correct number of clients +func Test_multipleMonthsActivityClients_write_entities(t *testing.T) { + index5 := int32(5) + index4 := int32(4) + data := &generation.ActivityLogMockInput{ + Write: []generation.WriteOptions{ + generation.WriteOptions_WRITE_ENTITIES, + }, + Data: []*generation.Data{ + { + // segments: 0:[x,y], 1:[z] + Month: &generation.Data_MonthsAgo{MonthsAgo: 3}, + Clients: &generation.Data_All{All: &generation.Clients{Clients: []*generation.Client{{Count: 3}}}}, + NumSegments: 2, + }, + { + // segments: 1:[a,b,c], 2:[d,e] + Month: &generation.Data_MonthsAgo{MonthsAgo: 2}, + Clients: &generation.Data_All{All: &generation.Clients{Clients: []*generation.Client{{Count: 5}}}}, + NumSegments: 3, + SkipSegmentIndexes: []int32{0}, + }, + { + // segments: 5:[f,g] + Month: &generation.Data_MonthsAgo{MonthsAgo: 1}, + Clients: &generation.Data_Segments{ + Segments: &generation.Segments{Segments: []*generation.Segment{{ + SegmentIndex: &index5, + Clients: &generation.Clients{Clients: []*generation.Client{{Count: 2}}}, + }}}, + }, + }, + { + // segments: 1:[], 2:[], 4:[n], 5:[o] + Month: &generation.Data_CurrentMonth{}, + EmptySegmentIndexes: []int32{1, 2}, + Clients: &generation.Data_Segments{ + Segments: &generation.Segments{Segments: []*generation.Segment{ + { + SegmentIndex: &index5, + Clients: &generation.Clients{Clients: []*generation.Client{{Count: 1}}}, + }, + { + SegmentIndex: &index4, + Clients: &generation.Clients{Clients: []*generation.Client{{Count: 1}}}, + }, + }}, + }, + }, + }, + } + + core, _, _ := TestCoreUnsealed(t) + marshaled, err := protojson.Marshal(data) + require.NoError(t, err) + req := logical.TestRequest(t, logical.CreateOperation, "internal/counters/activity/write") + req.Data = map[string]interface{}{"input": string(marshaled)} + resp, err := core.systemBackend.HandleRequest(namespace.RootContext(nil), req) + require.NoError(t, err) + paths := resp.Data["paths"].([]string) + require.Len(t, paths, 9) + + times, err := core.activityLog.availableLogs(context.Background()) + require.NoError(t, err) + require.Len(t, times, 4) + + sortPaths := func(monthPaths []string) { + sort.Slice(monthPaths, func(i, j int) bool { + iVal, _ := parseSegmentNumberFromPath(monthPaths[i]) + jVal, _ := parseSegmentNumberFromPath(monthPaths[j]) + return iVal < jVal + }) + } + + month0Paths := paths[0:4] + month1Paths := paths[4:5] + month2Paths := paths[5:7] + month3Paths := paths[7:9] + sortPaths(month0Paths) + sortPaths(month1Paths) + sortPaths(month2Paths) + sortPaths(month3Paths) + entities := func(paths []string) map[int][]*activity.EntityRecord { + segments := make(map[int][]*activity.EntityRecord) + for _, path := range paths { + segmentNum, _ := parseSegmentNumberFromPath(path) + entry, err := core.activityLog.view.Get(context.Background(), path) + require.NoError(t, err) + if entry == nil { + segments[segmentNum] = []*activity.EntityRecord{} + continue + } + activities := &activity.EntityActivityLog{} + err = proto.Unmarshal(entry.Value, activities) + require.NoError(t, err) + segments[segmentNum] = activities.Clients + } + return segments + } + month0Entities := entities(month0Paths) + require.Len(t, month0Entities, 4) + require.Contains(t, month0Entities, 1) + require.Contains(t, month0Entities, 2) + require.Contains(t, month0Entities, 4) + require.Contains(t, month0Entities, 5) + require.Len(t, month0Entities[1], 0) + require.Len(t, month0Entities[2], 0) + require.Len(t, month0Entities[4], 1) + require.Len(t, month0Entities[5], 1) + + month1Entities := entities(month1Paths) + require.Len(t, month1Entities, 1) + require.Contains(t, month1Entities, 5) + require.Len(t, month1Entities[5], 2) + + month2Entities := entities(month2Paths) + require.Len(t, month2Entities, 2) + require.Contains(t, month2Entities, 1) + require.Contains(t, month2Entities, 2) + require.Len(t, month2Entities[1], 3) + require.Len(t, month2Entities[2], 2) + + month3Entities := entities(month3Paths) + require.Len(t, month3Entities, 2) + require.Contains(t, month3Entities, 0) + require.Contains(t, month3Entities, 1) + require.Len(t, month3Entities[0], 2) + require.Len(t, month3Entities[1], 1) +}