diff --git a/changelog/12820.txt b/changelog/12820.txt new file mode 100644 index 000000000..c7b92e67f --- /dev/null +++ b/changelog/12820.txt @@ -0,0 +1,3 @@ +```release-note:feature +Add ClientID to Token With Entities in Activity Log: Vault tokens without entities are now tracked with client IDs and deduplicated in the Activity Log +``` diff --git a/command/operator_usage.go b/command/operator_usage.go index a6b4b59cc..cce3c6508 100644 --- a/command/operator_usage.go +++ b/command/operator_usage.go @@ -191,8 +191,11 @@ type UsageCommandNamespace struct { type UsageResponse struct { namespacePath string entityCount int64 - tokenCount int64 - clientCount int64 + // As per 1.9, the tokenCount field will contain the distinct non-entity + // token clients instead of each individual token. + tokenCount int64 + + clientCount int64 } func jsonNumberOK(m map[string]interface{}, key string) (int64, bool) { diff --git a/sdk/logical/request.go b/sdk/logical/request.go index e683217a6..580953147 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -214,6 +214,12 @@ type Request struct { // in response headers; it's attached to the request rather than the response // because not all requests yields non-nil responses. responseState *WALState + + // ClientID is the identity of the caller. If the token is associated with an + // entity, it will be the same as the EntityID . If the token has no entity, + // this will be the sha256(sorted policies + namespace) associated with the + // client token. + ClientID string } // Clone returns a deep copy of the request by using copystructure diff --git a/vault/activity/activity_log.pb.go b/vault/activity/activity_log.pb.go index ab60ae9f4..d59d3d3e1 100644 --- a/vault/activity/activity_log.pb.go +++ b/vault/activity/activity_log.pb.go @@ -20,18 +20,23 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// EntityRecord is generated the first time an entity is active -// each month. +// EntityRecord is generated the first time an client is active +// each month. This can store clients associated with entities +// or nonEntity clients, and really is a ClientRecord, not +// specifically an EntityRecord type EntityRecord struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - EntityID string `protobuf:"bytes,1,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` + ClientID string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` NamespaceID string `protobuf:"bytes,2,opt,name=namespace_id,json=namespaceID,proto3" json:"namespace_id,omitempty"` // using the Timestamp type would cost us an extra // 4 bytes per record to store nanoseconds. Timestamp int64 `protobuf:"varint,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + // non_entity records whether the given EntityRecord is + // for a TWE or an entity-bound token. 
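The new `ClientID` doc comment on `logical.Request` above states the formula used for tokens without entities: `sha256(sorted policies + namespace)`. Below is a minimal, self-contained sketch of that derivation, mirroring the `CreateClientID` steps added later in this patch; the helper name `tweClientID` is illustrative and not part of the change, and the delimiter values come from the constants introduced further down (`0x7F` between policies, `0x00` before the namespace ID).

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

// tweClientID mirrors the documented formula for tokens without entities:
// clientID = base64url(SHA256(sorted policies + namespace)), using the
// non-printable delimiters defined elsewhere in this patch.
func tweClientID(policies []string, namespaceID string) string {
	sorted := make([]string, len(policies))
	copy(sorted, policies)
	sort.Strings(sorted)

	var b strings.Builder
	for _, p := range sorted {
		b.WriteRune('\x7F') // sortedPoliciesTWEDelimiter
		b.WriteString(p)
	}
	b.WriteRune('\x00') // clientIDTWEDelimiter
	b.WriteString(namespaceID)

	input := b.String()[1:] // drop the leading delimiter, as CreateClientID does
	sum := sha256.Sum256([]byte(input))
	return base64.URLEncoding.EncodeToString(sum[:])
}

func main() {
	// Policy order does not matter, so repeated use of equivalent TWEs
	// deduplicates to a single client ID.
	a := tweClientID([]string{"default", "dev"}, "ns1")
	b := tweClientID([]string{"dev", "default"}, "ns1")
	fmt.Println(a == b) // true
}
```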
+ NonEntity bool `protobuf:"varint,4,opt,name=non_entity,json=nonEntity,proto3" json:"non_entity,omitempty"` } func (x *EntityRecord) Reset() { @@ -66,9 +71,9 @@ func (*EntityRecord) Descriptor() ([]byte, []int) { return file_vault_activity_activity_log_proto_rawDescGZIP(), []int{0} } -func (x *EntityRecord) GetEntityID() string { +func (x *EntityRecord) GetClientID() string { if x != nil { - return x.EntityID + return x.ClientID } return "" } @@ -87,6 +92,13 @@ func (x *EntityRecord) GetTimestamp() int64 { return 0 } +func (x *EntityRecord) GetNonEntity() bool { + if x != nil { + return x.NonEntity + } + return false +} + type LogFragment struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -95,8 +107,8 @@ type LogFragment struct { // hostname (or node ID?) where the fragment originated, // used for debugging. OriginatingNode string `protobuf:"bytes,1,opt,name=originating_node,json=originatingNode,proto3" json:"originating_node,omitempty"` - // active entities not yet in a log segment - Entities []*EntityRecord `protobuf:"bytes,2,rep,name=entities,proto3" json:"entities,omitempty"` + // active clients not yet in a log segment + Clients []*EntityRecord `protobuf:"bytes,2,rep,name=clients,proto3" json:"clients,omitempty"` // token counts not yet in a log segment, // indexed by namespace ID NonEntityTokens map[string]uint64 `protobuf:"bytes,3,rep,name=non_entity_tokens,json=nonEntityTokens,proto3" json:"non_entity_tokens,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"` @@ -141,9 +153,9 @@ func (x *LogFragment) GetOriginatingNode() string { return "" } -func (x *LogFragment) GetEntities() []*EntityRecord { +func (x *LogFragment) GetClients() []*EntityRecord { if x != nil { - return x.Entities + return x.Clients } return nil } @@ -155,12 +167,14 @@ func (x *LogFragment) GetNonEntityTokens() map[string]uint64 { return nil } +// This activity log stores records for both clients with entities +// and clients without entities type EntityActivityLog struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Entities []*EntityRecord `protobuf:"bytes,1,rep,name=entities,proto3" json:"entities,omitempty"` + Clients []*EntityRecord `protobuf:"bytes,1,rep,name=clients,proto3" json:"clients,omitempty"` } func (x *EntityActivityLog) Reset() { @@ -195,9 +209,9 @@ func (*EntityActivityLog) Descriptor() ([]byte, []int) { return file_vault_activity_activity_log_proto_rawDescGZIP(), []int{2} } -func (x *EntityActivityLog) GetEntities() []*EntityRecord { +func (x *EntityActivityLog) GetClients() []*EntityRecord { if x != nil { - return x.Entities + return x.Clients } return nil } @@ -292,52 +306,53 @@ var File_vault_activity_activity_log_proto protoreflect.FileDescriptor var file_vault_activity_activity_log_proto_rawDesc = []byte{ 0x0a, 0x21, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2f, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x5f, 0x6c, 0x6f, 0x67, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x22, 0x6c, 0x0a, - 0x0c, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x1b, 0x0a, - 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x6e, 0x61, - 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 
0x09, - 0x52, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1c, 0x0a, - 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x22, 0x88, 0x02, 0x0a, 0x0b, + 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x22, 0x8b, 0x01, + 0x0a, 0x0c, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x1b, + 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x6e, + 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1c, + 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x1d, 0x0a, 0x0a, + 0x6e, 0x6f, 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x09, 0x6e, 0x6f, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x22, 0x86, 0x02, 0x0a, 0x0b, 0x4c, 0x6f, 0x67, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6e, 0x67, 0x5f, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6e, 0x67, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x32, 0x0a, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x69, - 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x76, - 0x69, 0x74, 0x79, 0x2e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, - 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x56, 0x0a, 0x11, 0x6e, 0x6f, - 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, - 0x2e, 0x4c, 0x6f, 0x67, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x6f, 0x6e, - 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, - 0x79, 0x52, 0x0f, 0x6e, 0x6f, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x73, 0x1a, 0x42, 0x0a, 0x14, 0x4e, 0x6f, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x47, 0x0a, 0x11, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, - 0x41, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x4c, 0x6f, 0x67, 0x12, 0x32, 0x0a, 0x08, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, - 0xb4, 0x01, 0x0a, 0x0a, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x5f, - 0x0a, 0x15, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, - 0x70, 0x61, 0x63, 0x65, 
0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2c, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, - 0x75, 0x6e, 0x74, 0x2e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, - 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x12, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x1a, - 0x45, 0x0a, 0x17, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, - 0x61, 0x63, 0x65, 0x49, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x6f, 0x67, 0x46, 0x72, 0x61, - 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2b, 0x5a, - 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, - 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, - 0x74, 0x2f, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x6e, 0x67, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x30, 0x0a, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, + 0x74, 0x79, 0x2e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, + 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x56, 0x0a, 0x11, 0x6e, 0x6f, 0x6e, 0x5f, + 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x61, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x4c, + 0x6f, 0x67, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x6f, 0x6e, 0x45, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, + 0x0f, 0x6e, 0x6f, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, + 0x1a, 0x42, 0x0a, 0x14, 0x4e, 0x6f, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x3a, 0x02, 0x38, 0x01, 0x22, 0x45, 0x0a, 0x11, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x41, 0x63, + 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x4c, 0x6f, 0x67, 0x12, 0x30, 0x0a, 0x07, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x61, 0x63, 0x74, + 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x45, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x52, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0a, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x5f, 0x0a, 0x15, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x61, 0x63, 0x74, 0x69, + 0x76, 0x69, 0x74, 0x79, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x2e, + 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x42, 
0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, + 0x49, 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x12, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, + 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x1a, 0x45, 0x0a, 0x17, 0x43, + 0x6f, 0x75, 0x6e, 0x74, 0x42, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, + 0x64, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, + 0x38, 0x01, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x6f, 0x67, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, + 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x61, 0x63, + 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -363,9 +378,9 @@ var file_vault_activity_activity_log_proto_goTypes = []interface{}{ nil, // 6: activity.TokenCount.CountByNamespaceIDEntry } var file_vault_activity_activity_log_proto_depIDxs = []int32{ - 0, // 0: activity.LogFragment.entities:type_name -> activity.EntityRecord + 0, // 0: activity.LogFragment.clients:type_name -> activity.EntityRecord 5, // 1: activity.LogFragment.non_entity_tokens:type_name -> activity.LogFragment.NonEntityTokensEntry - 0, // 2: activity.EntityActivityLog.entities:type_name -> activity.EntityRecord + 0, // 2: activity.EntityActivityLog.clients:type_name -> activity.EntityRecord 6, // 3: activity.TokenCount.count_by_namespace_id:type_name -> activity.TokenCount.CountByNamespaceIDEntry 4, // [4:4] is the sub-list for method output_type 4, // [4:4] is the sub-list for method input_type diff --git a/vault/activity/activity_log.proto b/vault/activity/activity_log.proto index 03aaed577..e785627f3 100644 --- a/vault/activity/activity_log.proto +++ b/vault/activity/activity_log.proto @@ -4,14 +4,19 @@ option go_package = "github.com/hashicorp/vault/vault/activity"; package activity; -// EntityRecord is generated the first time an entity is active -// each month. +// EntityRecord is generated the first time an client is active + // each month. This can store clients associated with entities + // or nonEntity clients, and really is a ClientRecord, not + // specifically an EntityRecord message EntityRecord { - string entity_id = 1; + string client_id = 1; string namespace_id = 2; // using the Timestamp type would cost us an extra // 4 bytes per record to store nanoseconds. int64 timestamp = 3; + // non_entity records whether the given EntityRecord is + // for a TWE or an entity-bound token. + bool non_entity = 4; } message LogFragment { @@ -19,16 +24,18 @@ message LogFragment { // used for debugging. 
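The new `non_entity` field is a proto3 `bool`, so its zero value is `false`; that is what makes the backward-compatibility comment hold: segments written by pre-1.9 nodes unmarshal with `NonEntity == false` and keep being treated as entity clients. A small sketch (field values illustrative) using the generated `activity` package from this patch:

```go
package main

import (
	"fmt"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/hashicorp/vault/vault/activity"
)

func main() {
	// A record for a token without an entity; ClientID holds the
	// policies+namespace hash rather than an entity ID.
	rec := &activity.EntityRecord{
		ClientID:    "F_ZvHQ...", // illustrative base64url hash
		NamespaceID: "root",
		Timestamp:   time.Now().Unix(),
		NonEntity:   true,
	}
	raw, _ := proto.Marshal(rec)

	out := &activity.EntityRecord{}
	_ = proto.Unmarshal(raw, out)
	// Records that never set the field decode with NonEntity == false,
	// so older segments still count as entity clients.
	fmt.Println(out.GetClientID(), out.GetNonEntity())
}
```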
string originating_node = 1; - // active entities not yet in a log segment - repeated EntityRecord entities = 2; + // active clients not yet in a log segment + repeated EntityRecord clients = 2; // token counts not yet in a log segment, // indexed by namespace ID map non_entity_tokens = 3; } +// This activity log stores records for both clients with entities +// and clients without entities message EntityActivityLog { - repeated EntityRecord entities = 1; + repeated EntityRecord clients = 1; } message TokenCount { diff --git a/vault/activity_log.go b/vault/activity_log.go index 33f9c24be..fb866640a 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -2,6 +2,8 @@ package vault import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -11,6 +13,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "github.com/golang/protobuf/proto" log "github.com/hashicorp/go-hclog" @@ -19,6 +22,7 @@ import ( "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/activity" + "github.com/mitchellh/copystructure" ) const ( @@ -58,18 +62,38 @@ const ( // standby fragment before sending it to the active node. // Estimates as 8KiB / 64 bytes = 128 activityFragmentStandbyCapacity = 128 + + // Delimiter between the string fields used to generate a client + // ID for tokens without entities. This is the 0 character, which + // is a non-printable string. Please see unicode.IsPrint for details. + clientIDTWEDelimiter = rune('\x00') + + // Delimiter between each policy in the sorted policies used to + // generate a client ID for tokens without entities. This is the 127 + // character, which is a non-printable string. Please see unicode.IsPrint + // for details. + sortedPoliciesTWEDelimiter = rune('\x7F') + + // trackedTWESegmentPeriod is a time period of a little over a month, and represents + // the amount of time that needs to pass after a 1.9 or later upgrade to result in + // all fragments and segments no longer storing token counts in the directtokens + // storage path. + trackedTWESegmentPeriod = 35 * 24 ) type segmentInfo struct { startTimestamp int64 - currentEntities *activity.EntityActivityLog - tokenCount *activity.TokenCount - entitySequenceNumber uint64 + currentClients *activity.EntityActivityLog + clientSequenceNumber uint64 + // DEPRECATED + // This field is needed for backward compatibility with fragments + // and segments created with vault versions before 1.9. + tokenCount *activity.TokenCount } type clients struct { - distinctEntities uint64 - nonEntityTokens uint64 + distinctEntities uint64 + distinctNonEntities uint64 } // ActivityLog tracks unique entity counts and non-entity token counts. @@ -84,7 +108,7 @@ type ActivityLog struct { // Acquire "l" before fragmentLock if both must be held. l sync.RWMutex - // fragmentLock protects enable, activeEntities, fragment, standbyFragmentsReceived + // fragmentLock protects enable, activeClients, fragment, standbyFragmentsReceived fragmentLock sync.RWMutex // enabled indicates if the activity log is enabled for this cluster. @@ -145,15 +169,16 @@ type ActivityLog struct { // for testing: is config currently being invalidated. protected by l configInvalidationInProgress bool - // entityTracker tracks active entities this month. Protected by fragmentLock. - entityTracker *EntityTracker + // clientTracker tracks active clients this month. Protected by fragmentLock. 
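One reason the two delimiter constants above matter: without a separator between the sorted policies, different policy sets could concatenate to the same hash input and collapse into one client ID. A tiny illustrative sketch of the collision the `0x7F` separator avoids (the `join` helper is not part of the patch):

```go
package main

import "fmt"

func main() {
	join := func(delim string, policies []string) string {
		s := ""
		for _, p := range policies {
			s += delim + p
		}
		return s
	}

	// With no separator, {"ab","c"} and {"a","bc"} would both hash "abc".
	fmt.Println(join("", []string{"ab", "c"}) == join("", []string{"a", "bc"})) // true: collision
	// The non-printable 0x7F delimiter keeps the hash inputs distinct.
	fmt.Println(join("\x7F", []string{"ab", "c"}) == join("\x7F", []string{"a", "bc"})) // false
}
```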
+ clientTracker *ClientTracker } -type EntityTracker struct { - // All known active entities this month; use fragmentLock read-locked +type ClientTracker struct { + // All known active clients this month; use fragmentLock read-locked // to check whether it already exists. - activeEntities map[string]struct{} - entityCountByNamespaceID map[string]uint64 + activeClients map[string]struct{} + entityCountByNamespaceID map[string]uint64 + nonEntityCountByNamespaceID map[string]uint64 } // These non-persistent configuration options allow us to disable @@ -185,19 +210,23 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me sendCh: make(chan struct{}, 1), // buffered so it can be triggered by fragment size writeCh: make(chan struct{}, 1), // same for full segment doneCh: make(chan struct{}, 1), - entityTracker: &EntityTracker{ - activeEntities: make(map[string]struct{}), - entityCountByNamespaceID: make(map[string]uint64), + clientTracker: &ClientTracker{ + activeClients: make(map[string]struct{}), + entityCountByNamespaceID: make(map[string]uint64), + nonEntityCountByNamespaceID: make(map[string]uint64), }, currentSegment: segmentInfo{ startTimestamp: 0, - currentEntities: &activity.EntityActivityLog{ - Entities: make([]*activity.EntityRecord, 0), + currentClients: &activity.EntityActivityLog{ + Clients: make([]*activity.EntityRecord, 0), }, + // tokenCount is deprecated, but must still exist for the current segment + // so the fragment that was using TWEs before the 1.9 changes + // can be flushed to the current segment. tokenCount: &activity.TokenCount{ CountByNamespaceID: make(map[string]uint64), }, - entitySequenceNumber: 0, + clientSequenceNumber: 0, }, standbyFragmentsReceived: make([]*activity.LogFragment, 0), } @@ -251,7 +280,7 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for // Measure the current fragment if localFragment != nil { a.metrics.IncrCounterWithLabels([]string{"core", "activity", "fragment_size"}, - float32(len(localFragment.Entities)), + float32(len(localFragment.Clients)), []metricsutil.Label{ {"type", "entity"}, }) @@ -269,15 +298,24 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for if f == nil { continue } - for _, e := range f.Entities { + for _, e := range f.Clients { // We could sort by timestamp to see which is first. // We'll ignore that; the order of the append above means // that we choose entries in localFragment over those // from standby nodes. - newEntities[e.EntityID] = e + newEntities[e.ClientID] = e saveChanges = true } + // As of 1.9, a fragment should no longer have any NonEntityTokens. However + // in order to not lose any information about the current segment during the + // month when the client upgrades to 1.9, we must retain this functionality. for ns, val := range f.NonEntityTokens { + // We track these pre-1.9 values in the old location, which is + // a.currentSegment.tokenCount, as opposed to the counter that stores tokens + // without entities that have client IDs, namely + // a.clientTracker.nonEntityCountByNamespaceID. This preserves backward + // compatibility for the precomputedQueryWorkers and the segment storing + // logic. a.currentSegment.tokenCount.CountByNamespaceID[ns] += val saveChanges = true } @@ -288,49 +326,49 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for } // Will all new entities fit? If not, roll over to a new segment. 
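The renamed `ClientTracker` now keeps two per-namespace tallies instead of one. A sketch of a package-internal test (in the style of `activity_log_test.go`) showing how `addClient`, defined later in this patch, deduplicates by client ID and splits the counts between entity and non-entity clients; the test name and IDs are illustrative.

```go
func TestClientTracker_AddClientSketch(t *testing.T) {
	ct := &ClientTracker{
		activeClients:               make(map[string]struct{}),
		entityCountByNamespaceID:    make(map[string]uint64),
		nonEntityCountByNamespaceID: make(map[string]uint64),
	}

	ct.addClient(&activity.EntityRecord{ClientID: "entity-1", NamespaceID: "root"})
	ct.addClient(&activity.EntityRecord{ClientID: "twe-hash", NamespaceID: "root", NonEntity: true})
	// A repeated client ID is ignored, which is what deduplicates TWEs.
	ct.addClient(&activity.EntityRecord{ClientID: "twe-hash", NamespaceID: "root", NonEntity: true})

	if got := ct.entityCountByNamespaceID["root"]; got != 1 {
		t.Fatalf("expected 1 entity client, got %d", got)
	}
	if got := ct.nonEntityCountByNamespaceID["root"]; got != 1 {
		t.Fatalf("expected 1 non-entity client, got %d", got)
	}
	if got := len(ct.activeClients); got != 2 {
		t.Fatalf("expected 2 active clients, got %d", got)
	}
}
```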
- available := activitySegmentEntityCapacity - len(a.currentSegment.currentEntities.Entities) + available := activitySegmentEntityCapacity - len(a.currentSegment.currentClients.Clients) remaining := available - len(newEntities) excess := 0 if remaining < 0 { excess = -remaining } - segmentEntities := a.currentSegment.currentEntities.Entities - excessEntities := make([]*activity.EntityRecord, 0, excess) + segmentClients := a.currentSegment.currentClients.Clients + excessClients := make([]*activity.EntityRecord, 0, excess) for _, record := range newEntities { if available > 0 { - segmentEntities = append(segmentEntities, record) + segmentClients = append(segmentClients, record) available -= 1 } else { - excessEntities = append(excessEntities, record) + excessClients = append(excessClients, record) } } - a.currentSegment.currentEntities.Entities = segmentEntities + a.currentSegment.currentClients.Clients = segmentClients err := a.saveCurrentSegmentInternal(ctx, force) if err != nil { // The current fragment(s) have already been placed into the in-memory - // segment, but we may lose any excess (in excessEntities). + // segment, but we may lose any excess (in excessClients). // There isn't a good way to unwind the transaction on failure, // so we may just lose some records. return err } if available <= 0 { - if a.currentSegment.entitySequenceNumber >= activityLogMaxSegmentPerMonth { + if a.currentSegment.clientSequenceNumber >= activityLogMaxSegmentPerMonth { // Cannot send as Warn because it will repeat too often, // and disabling/renabling would be complicated. - a.logger.Trace("too many segments in current month", "dropped", len(excessEntities)) + a.logger.Trace("too many segments in current month", "dropped", len(excessClients)) return nil } // Rotate to next segment - a.currentSegment.entitySequenceNumber += 1 - if len(excessEntities) > activitySegmentEntityCapacity { - a.logger.Warn("too many new active entities, dropping tail", "entities", len(excessEntities)) - excessEntities = excessEntities[:activitySegmentEntityCapacity] + a.currentSegment.clientSequenceNumber += 1 + if len(excessClients) > activitySegmentEntityCapacity { + a.logger.Warn("too many new active entities, dropping tail", "entities", len(excessClients)) + excessClients = excessClients[:activitySegmentEntityCapacity] } - a.currentSegment.currentEntities.Entities = excessEntities + a.currentSegment.currentClients.Clients = excessClients err := a.saveCurrentSegmentInternal(ctx, force) if err != nil { return err @@ -341,12 +379,20 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for // :force: forces a save of tokens/entities even if the in-memory log is empty func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool) error { - entityPath := fmt.Sprintf("log/entity/%d/%d", a.currentSegment.startTimestamp, a.currentSegment.entitySequenceNumber) + entityPath := fmt.Sprintf("log/entity/%d/%d", a.currentSegment.startTimestamp, a.currentSegment.clientSequenceNumber) // RFC (VLT-120) defines this as 1-indexed, but it should be 0-indexed tokenPath := fmt.Sprintf("log/directtokens/%d/0", a.currentSegment.startTimestamp) - if len(a.currentSegment.currentEntities.Entities) > 0 || force { - entities, err := proto.Marshal(a.currentSegment.currentEntities) + for _, client := range a.currentSegment.currentClients.Clients { + // Explicitly catch and throw clear error message if client ID creation and storage + // results in a []byte that doesn't assert into a valid string. 
+ if !utf8.ValidString(client.ClientID) { + return fmt.Errorf("client ID %q is not a valid string:", client.ClientID) + } + } + + if len(a.currentSegment.currentClients.Clients) > 0 || force { + entities, err := proto.Marshal(a.currentSegment.currentClients) if err != nil { return err } @@ -361,7 +407,22 @@ func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool } } + // We must still allow for the tokenCount of the current segment to + // be written to storage, since if we remove this code we will incur + // data loss for one segment's worth of TWEs. if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 || force { + oldestVersion, oldestUpgradeTime, err := a.core.FindOldestVersionTimestamp() + switch { + case err != nil: + a.logger.Error(fmt.Sprintf("unable to retrieve oldest version timestamp: %s", err.Error())) + case len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 && + (oldestUpgradeTime.Add(time.Duration(trackedTWESegmentPeriod * time.Hour)).Before(time.Now())): + a.logger.Error(fmt.Sprintf("storing nonzero token count over a month after vault was upgraded to %s", oldestVersion)) + default: + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + a.logger.Info("storing nonzero token count") + } + } tokenCount, err := proto.Marshal(a.currentSegment.tokenCount) if err != nil { return err @@ -376,7 +437,6 @@ func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool return err } } - return nil } @@ -511,7 +571,7 @@ func (a *ActivityLog) WalkTokenSegments(ctx context.Context, return err } if raw == nil { - a.logger.Warn("expected token segment not found", "startTime", startTime, "segment", path) + a.logger.Trace("no tokens without entities stored without tracking", "startTime", startTime, "segment", path) continue } out := &activity.TokenCount{} @@ -544,8 +604,8 @@ func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time // Handle the (unlikely) case where the end of the month has been reached while background loading. // Or the feature has been disabled. 
if a.enabled && startTime.Unix() == a.currentSegment.startTimestamp { - for _, ent := range out.Entities { - a.entityTracker.addEntity(ent) + for _, ent := range out.Clients { + a.clientTracker.addClient(ent) } } a.fragmentLock.Unlock() @@ -554,10 +614,10 @@ func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time return nil } -// loadCurrentEntitySegment loads the most recent segment (for "this month") into memory -// (to append new entries), and to the activeEntities to avoid duplication +// loadCurrentClientSegment loads the most recent segment (for "this month") into memory +// (to append new entries), and to the activeClients to avoid duplication // call with fragmentLock and l held -func (a *ActivityLog) loadCurrentEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64) error { +func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime time.Time, sequenceNum uint64) error { path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) data, err := a.view.Get(ctx, path) if err != nil { @@ -573,19 +633,19 @@ func (a *ActivityLog) loadCurrentEntitySegment(ctx context.Context, startTime ti if !a.core.perfStandby { a.currentSegment = segmentInfo{ startTimestamp: startTime.Unix(), - currentEntities: &activity.EntityActivityLog{ - Entities: out.Entities, + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, }, tokenCount: a.currentSegment.tokenCount, - entitySequenceNumber: sequenceNum, + clientSequenceNumber: sequenceNum, } } else { // populate this for edge case checking (if end of month passes while background loading on standby) a.currentSegment.startTimestamp = startTime.Unix() } - for _, ent := range out.Entities { - a.entityTracker.addEntity(ent) + for _, ent := range out.Clients { + a.clientTracker.addClient(ent) } return nil @@ -635,6 +695,10 @@ func (a *ActivityLog) loadTokenCount(ctx context.Context, startTime time.Time) e if out.CountByNamespaceID == nil { out.CountByNamespaceID = make(map[string]uint64) } + + // We must load the tokenCount of the current segment into the activity log + // so that TWEs counted before the introduction of a client ID for TWEs are + // still reported in the partial client counts. a.currentSegment.tokenCount = out return nil @@ -688,17 +752,22 @@ func (a *ActivityLog) newSegmentAtGivenTime(t time.Time) { // Should be called with fragmentLock and l held. 
func (a *ActivityLog) resetCurrentLog() { a.currentSegment.startTimestamp = 0 - a.currentSegment.currentEntities = &activity.EntityActivityLog{ - Entities: make([]*activity.EntityRecord, 0), + a.currentSegment.currentClients = &activity.EntityActivityLog{ + Clients: make([]*activity.EntityRecord, 0), } + + // We must still initialize the tokenCount to recieve tokenCounts from fragments + // during the month where customers upgrade to 1.9 a.currentSegment.tokenCount = &activity.TokenCount{ CountByNamespaceID: make(map[string]uint64), } - a.currentSegment.entitySequenceNumber = 0 + + a.currentSegment.clientSequenceNumber = 0 a.fragment = nil - a.entityTracker.activeEntities = make(map[string]struct{}) - a.entityTracker.entityCountByNamespaceID = make(map[string]uint64) + a.clientTracker.activeClients = make(map[string]struct{}) + a.clientTracker.entityCountByNamespaceID = make(map[string]uint64) + a.clientTracker.nonEntityCountByNamespaceID = make(map[string]uint64) a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) } @@ -814,7 +883,9 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro return nil } - // load token counts from storage into memory + // load token counts from storage into memory. As of 1.9, this functionality + // is still required since without it, we would lose replicated TWE counts for the + // current segment. if !a.core.perfStandby { err = a.loadTokenCount(ctx, mostRecent) if err != nil { @@ -832,7 +903,7 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro return nil } - err = a.loadCurrentEntitySegment(ctx, mostRecent, lastSegment) + err = a.loadCurrentClientSegment(ctx, mostRecent, lastSegment) if err != nil || lastSegment == 0 { return err } @@ -1000,18 +1071,12 @@ func (c *Core) setupActivityLog(ctx context.Context, wg *sync.WaitGroup) error { }() } - // Link the token store to this core - c.tokenStore.SetActivityLog(manager) - return nil } // stopActivityLog removes the ActivityLog from Core // and frees any resources. func (c *Core) stopActivityLog() { - if c.tokenStore != nil { - c.tokenStore.SetActivityLog(nil) - } // preSeal may run before startActivityLog got a chance to complete. if c.activityLog != nil { @@ -1109,8 +1174,9 @@ func (a *ActivityLog) perfStandbyFragmentWorker() { // clear active entity set a.fragmentLock.Lock() - a.entityTracker.activeEntities = make(map[string]struct{}) - a.entityTracker.entityCountByNamespaceID = make(map[string]uint64) + a.clientTracker.activeClients = make(map[string]struct{}) + a.clientTracker.entityCountByNamespaceID = make(map[string]uint64) + a.clientTracker.nonEntityCountByNamespaceID = make(map[string]uint64) a.fragmentLock.Unlock() // Set timer for next month. @@ -1288,17 +1354,21 @@ func (c *Core) ResetActivityLog() []*activity.LogFragment { return allFragments } -// AddEntityToFragment checks an entity ID for uniqueness and +func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp int64) { + a.AddClientToFragment(entityID, namespaceID, timestamp, false) +} + +// AddClientToFragment checks a client ID for uniqueness and // if not already present, adds it to the current fragment. // The timestamp is a Unix timestamp *without* nanoseconds, as that // is what token.CreationTime uses. 
-func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp int64) { +func (a *ActivityLog) AddClientToFragment(clientID string, namespaceID string, timestamp int64, isTWE bool) { // Check whether entity ID already recorded var present bool a.fragmentLock.RLock() if a.enabled { - _, present = a.entityTracker.activeEntities[entityID] + _, present = a.clientTracker.activeClients[clientID] } else { present = true } @@ -1312,33 +1382,28 @@ func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, t defer a.fragmentLock.Unlock() // Re-check entity ID after re-acquiring lock - _, present = a.entityTracker.activeEntities[entityID] + _, present = a.clientTracker.activeClients[clientID] if present { return } a.createCurrentFragment() - entityRecord := &activity.EntityRecord{ - EntityID: entityID, + clientRecord := &activity.EntityRecord{ + ClientID: clientID, NamespaceID: namespaceID, Timestamp: timestamp, } - a.fragment.Entities = append(a.fragment.Entities, entityRecord) - a.entityTracker.addEntity(entityRecord) -} -func (a *ActivityLog) AddTokenToFragment(namespaceID string) { - a.fragmentLock.Lock() - defer a.fragmentLock.Unlock() - - if !a.enabled { - return + // Track whether the clientID corresponds to a token without an entity or not. + // This field is backward compatible, as the default is 0, so records created + // from pre-1.9 activityLog code will automatically be marked as having an entity. + if isTWE { + clientRecord.NonEntity = true } - a.createCurrentFragment() - - a.fragment.NonEntityTokens[namespaceID] += 1 + a.fragment.Clients = append(a.fragment.Clients, clientRecord) + a.clientTracker.addClient(clientRecord) } // Create the current fragment if it doesn't already exist. @@ -1347,7 +1412,7 @@ func (a *ActivityLog) createCurrentFragment() { if a.fragment == nil { a.fragment = &activity.LogFragment{ OriginatingNode: a.nodeID, - Entities: make([]*activity.EntityRecord, 0, 120), + Clients: make([]*activity.EntityRecord, 0, 120), NonEntityTokens: make(map[string]uint64), } a.fragmentCreation = time.Now().UTC() @@ -1368,8 +1433,8 @@ func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { return } - for _, e := range fragment.Entities { - a.entityTracker.addEntity(e) + for _, e := range fragment.Clients { + a.clientTracker.addClient(e) } a.standbyFragmentsReceived = append(a.standbyFragmentsReceived, fragment) @@ -1518,19 +1583,74 @@ func (a *ActivityLog) loadConfigOrDefault(ctx context.Context) (activityConfig, return config, nil } -// HandleTokenCreation adds the TokenEntry to the current fragment of the activity log. -// This currently occurs on token creation (for tokens without entities) -// or token usage (for tokens associated with entities) -func (a *ActivityLog) HandleTokenCreation(entry *logical.TokenEntry) { - // enabled state is checked in both of these functions, - // because we have to grab a mutex there anyway. - if entry.EntityID != "" { - a.AddEntityToFragment(entry.EntityID, entry.NamespaceID, entry.CreationTime) - } else { - if !IsWrappingToken(entry) { - a.AddTokenToFragment(entry.NamespaceID) - } +// HandleTokenUsage adds the TokenEntry to the current fragment of the activity log. +// This currently occurs on token usage only. +func (a *ActivityLog) HandleTokenUsage(entry *logical.TokenEntry) { + // First, check if a is enabled, so as to avoid the cost of creating an ID for + // tokens without entities in the case where it not. 
+ a.fragmentLock.RLock() + if !a.enabled { + a.fragmentLock.RUnlock() + return } + a.fragmentLock.RUnlock() + + // Do not count wrapping tokens in client count + if IsWrappingToken(entry) { + return + } + + // Do not count root tokens in client count. This includes generated root tokens + // as well. + if len(entry.Policies) == 1 && entry.Policies[0] == "root" { + return + } + + // Parse an entry's client ID and add it to the activity log + clientID, isTWE := a.CreateClientID(entry) + a.AddClientToFragment(clientID, entry.NamespaceID, entry.CreationTime, isTWE) +} + +// CreateClientID returns the client ID, and a boolean which is false if the clientID +// has an entity, and true otherwise +func (a *ActivityLog) CreateClientID(entry *logical.TokenEntry) (string, bool) { + var clientIDInputBuilder strings.Builder + + // if entry has an associated entity ID, return it + if entry.EntityID != "" { + return entry.EntityID, false + } + + // The entry is associated with a TWE (token without entity). In this case + // we must create a client ID by calculating the following formula: + // clientID = SHA256(sorted policies + namespace) + + // Step 1: Copy entry policies to a new struct + sortedPolicies := make([]string, len(entry.Policies)) + copy(sortedPolicies, entry.Policies) + + // Step 2: Sort and join copied policies + sort.Strings(sortedPolicies) + for _, pol := range sortedPolicies { + clientIDInputBuilder.WriteRune(sortedPoliciesTWEDelimiter) + clientIDInputBuilder.WriteString(pol) + } + + // Step 3: Add namespace ID + clientIDInputBuilder.WriteRune(clientIDTWEDelimiter) + clientIDInputBuilder.WriteString(entry.NamespaceID) + + if clientIDInputBuilder.Len() == 0 { + a.logger.Error("vault token with no entity ID, policies, or namespace was recorded " + + "in the activity log") + return "", true + } + // Step 4: Remove the first character in the string, as it's an unnecessary delimiter + clientIDInput := clientIDInputBuilder.String()[1:] + + // Step 5: Hash the sum + hashed := sha256.Sum256([]byte(clientIDInput)) + return base64.URLEncoding.EncodeToString(hashed[:]), true } func (a *ActivityLog) namespaceToLabel(ctx context.Context, nsID string) string { @@ -1626,24 +1746,31 @@ func (a *ActivityLog) precomputedQueryWorker() error { type NamespaceCounts struct { // entityID -> present Entities map[string]struct{} - // count + // count. 
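`HandleTokenUsage` and `CreateClientID` together define which callers become clients: wrapping and root tokens are skipped, entity tokens are keyed by their entity ID, and TWEs are keyed by the policies+namespace hash. A sketch of an additional package-internal test exercising those rules; it reuses the helpers already used by the updated tests in this patch (`TestCoreUnsealed`, `SetEnable`, `GetActiveClients`), and the fixture values are illustrative.

```go
func TestActivityLog_CountingRulesSketch(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	a := core.activityLog
	a.SetEnable(true)
	a.SetStartTimestamp(time.Now().Unix())

	// Not counted: root and wrapping tokens.
	a.HandleTokenUsage(&logical.TokenEntry{NamespaceID: "root", Policies: []string{"root"}})
	a.HandleTokenUsage(&logical.TokenEntry{
		Path: "test", NamespaceID: "root",
		Policies: []string{responseWrappingPolicyName}, CreationTime: time.Now().Unix(),
	})

	// Counted once: repeated use of an equivalent TWE shares one client ID.
	twe := &logical.TokenEntry{NamespaceID: "ns1_id", Policies: []string{"dev", "default"}}
	a.HandleTokenUsage(twe)
	a.HandleTokenUsage(twe)
	if id, isTWE := a.CreateClientID(twe); !isTWE || id == "" {
		t.Fatalf("expected a non-empty TWE client ID, got %q (isTWE=%v)", id, isTWE)
	}

	// Counted by entity ID when one exists.
	a.HandleTokenUsage(&logical.TokenEntry{EntityID: "foo", NamespaceID: "ns1_id"})

	if got := len(core.GetActiveClients()); got != 2 {
		t.Fatalf("expected 2 active clients, got %d", got)
	}
}
```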
This exists for backward compatibility Tokens uint64 + // clientID -> present + NonEntities map[string]struct{} } byNamespace := make(map[string]*NamespaceCounts) createNs := func(namespaceID string) { if _, namespacePresent := byNamespace[namespaceID]; !namespacePresent { byNamespace[namespaceID] = &NamespaceCounts{ - Entities: make(map[string]struct{}), - Tokens: 0, + Entities: make(map[string]struct{}), + Tokens: 0, + NonEntities: make(map[string]struct{}), } } } walkEntities := func(l *activity.EntityActivityLog) { - for _, e := range l.Entities { + for _, e := range l.Clients { createNs(e.NamespaceID) - byNamespace[e.NamespaceID].Entities[e.EntityID] = struct{}{} + if e.NonEntity == true { + byNamespace[e.NamespaceID].NonEntities[e.ClientID] = struct{}{} + } else { + byNamespace[e.NamespaceID].Entities[e.ClientID] = struct{}{} + } } } walkTokens := func(l *activity.TokenCount) { @@ -1689,7 +1816,7 @@ func (a *ActivityLog) precomputedQueryWorker() error { pq.Namespaces = append(pq.Namespaces, &activity.NamespaceRecord{ NamespaceID: nsID, Entities: uint64(len(counts.Entities)), - NonEntityTokens: counts.Tokens, + NonEntityTokens: counts.Tokens + uint64(len(counts.NonEntities)), }) // If this is the most recent month, or the start of the reporting period, output @@ -1702,6 +1829,13 @@ func (a *ActivityLog) precomputedQueryWorker() error { {Name: "namespace", Value: a.namespaceToLabel(ctx, nsID)}, }, ) + a.metrics.SetGaugeWithLabels( + []string{"identity", "nonentity", "active", "monthly"}, + float32(len(counts.NonEntities))+float32(counts.Tokens), + []metricsutil.Label{ + {Name: "namespace", Value: a.namespaceToLabel(ctx, nsID)}, + }, + ) } else if startTime == activePeriodStart { a.metrics.SetGaugeWithLabels( []string{"identity", "entity", "active", "reporting_period"}, @@ -1710,6 +1844,13 @@ func (a *ActivityLog) precomputedQueryWorker() error { {Name: "namespace", Value: a.namespaceToLabel(ctx, nsID)}, }, ) + a.metrics.SetGaugeWithLabels( + []string{"identity", "nonentity", "active", "reporting_period"}, + float32(len(counts.NonEntities))+float32(counts.Tokens), + []metricsutil.Label{ + {Name: "namespace", Value: a.namespaceToLabel(ctx, nsID)}, + }, + ) } } @@ -1786,7 +1927,7 @@ func (a *ActivityLog) PartialMonthMetrics(ctx context.Context) ([]metricsutil.Ga // Empty list return []metricsutil.GaugeLabelValues{}, nil } - count := len(a.entityTracker.activeEntities) + count := len(a.clientTracker.activeClients) return []metricsutil.GaugeLabelValues{ { @@ -1820,9 +1961,17 @@ func (a *ActivityLog) partialMonthClientCount(ctx context.Context) (map[string]i responseData := make(map[string]interface{}) totalEntities := 0 totalTokens := 0 - - clientCountTable := createClientCountTable(a.entityTracker.entityCountByNamespaceID, a.currentSegment.tokenCount.CountByNamespaceID) - + nonEntityTokensMapInterface, err := copystructure.Copy(a.clientTracker.nonEntityCountByNamespaceID) + if err != nil { + return nil, fmt.Errorf("error making deep copy of nonEntityCounts: %+w", err) + } + nonEntityTokensMap := nonEntityTokensMapInterface.(map[string]uint64) + // Merge the tokenCounts created pre-clientID with the newly counted + // clientID tokens, if tokenCounts exist. 
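Both reporting paths above fold the deprecated per-namespace token tally into the non-entity side of the count: the precomputed query reports `NonEntityTokens = counts.Tokens + len(counts.NonEntities)`, and the partial-month path merges `tokenCount.CountByNamespaceID` into the distinct-TWE map before building the table. A short worked example of the arithmetic (numbers are illustrative only):

```go
package main

import "fmt"

func main() {
	// Per namespace, after this patch (values illustrative):
	distinctEntities := 40   // len(counts.Entities)
	distinctTWEClients := 10 // len(counts.NonEntities), deduplicated by client ID
	legacyTokens := 5        // pre-1.9 counts.Tokens still carried in tokenCount storage

	nonEntityTokens := distinctTWEClients + legacyTokens // reported NonEntityTokens
	clients := distinctEntities + nonEntityTokens        // reported Clients
	fmt.Println(nonEntityTokens, clients)                // 15 55
}
```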
+ for nsID, count := range a.currentSegment.tokenCount.CountByNamespaceID { + nonEntityTokensMap[nsID] += count + } + clientCountTable := createClientCountTable(a.clientTracker.entityCountByNamespaceID, nonEntityTokensMap) queryNS, err := namespace.FromContext(ctx) if err != nil { return nil, err @@ -1849,13 +1998,13 @@ func (a *ActivityLog) partialMonthClientCount(ctx context.Context) (map[string]i NamespacePath: displayPath, Counts: ClientCountResponse{ DistinctEntities: int(clients.distinctEntities), - NonEntityTokens: int(clients.nonEntityTokens), - Clients: int(clients.distinctEntities + clients.nonEntityTokens), + NonEntityTokens: int(clients.distinctNonEntities), + Clients: int(clients.distinctEntities + clients.distinctNonEntities), }, }) totalEntities += int(clients.distinctEntities) - totalTokens += int(clients.nonEntityTokens) + totalTokens += int(clients.distinctNonEntities) } } @@ -1872,30 +2021,33 @@ func (a *ActivityLog) partialMonthClientCount(ctx context.Context) (map[string]i return responseData, nil } -//createClientCountTable maps the entitycount and token count to the namespace id -func createClientCountTable(entityMap map[string]uint64, tokenMap map[string]uint64) map[string]*clients { +// createClientCountTable maps the entitycount and token count to the namespace id. +func createClientCountTable(entityMap map[string]uint64, nonEntityMap map[string]uint64) map[string]*clients { clientCountTable := make(map[string]*clients) for nsID, count := range entityMap { if _, ok := clientCountTable[nsID]; !ok { - clientCountTable[nsID] = &clients{distinctEntities: 0, nonEntityTokens: 0} + clientCountTable[nsID] = &clients{distinctEntities: 0, distinctNonEntities: 0} } clientCountTable[nsID].distinctEntities += count } - for nsID, count := range tokenMap { + for nsID, count := range nonEntityMap { if _, ok := clientCountTable[nsID]; !ok { - clientCountTable[nsID] = &clients{distinctEntities: 0, nonEntityTokens: 0} + clientCountTable[nsID] = &clients{distinctEntities: 0, distinctNonEntities: 0} } - clientCountTable[nsID].nonEntityTokens += count - + clientCountTable[nsID].distinctNonEntities += count } return clientCountTable } -func (et *EntityTracker) addEntity(e *activity.EntityRecord) { - if _, ok := et.activeEntities[e.EntityID]; !ok { - et.activeEntities[e.EntityID] = struct{}{} - et.entityCountByNamespaceID[e.NamespaceID] += 1 +func (ct *ClientTracker) addClient(e *activity.EntityRecord) { + if _, ok := ct.activeClients[e.ClientID]; !ok { + ct.activeClients[e.ClientID] = struct{}{} + if e.NonEntity == true { + ct.nonEntityCountByNamespaceID[e.NamespaceID] += 1 + } else { + ct.entityCountByNamespaceID[e.NamespaceID] += 1 + } } } diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 0ba97d29d..a7fbe9d54 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -2,6 +2,8 @@ package vault import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -52,7 +54,7 @@ func TestActivityLog_Creation(t *testing.T) { t.Errorf("mismatched node ID, %q vs %q", a.fragment.OriginatingNode, a.nodeID) } - if a.fragment.Entities == nil { + if a.fragment.Clients == nil { t.Fatal("no fragment entity slice") } @@ -60,13 +62,13 @@ func TestActivityLog_Creation(t *testing.T) { t.Fatal("no fragment token map") } - if len(a.fragment.Entities) != 1 { - t.Fatalf("wrong number of entities %v", len(a.fragment.Entities)) + if len(a.fragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(a.fragment.Clients)) } - er := 
a.fragment.Entities[0] - if er.EntityID != entity_id { - t.Errorf("mimatched entity ID, %q vs %q", er.EntityID, entity_id) + er := a.fragment.Clients[0] + if er.ClientID != entity_id { + t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, entity_id) } if er.NamespaceID != namespace_id { t.Errorf("mimatched namespace ID, %q vs %q", er.NamespaceID, namespace_id) @@ -112,7 +114,7 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { a.fragmentLock.Unlock() const namespace_id = "ns123" - a.HandleTokenCreation(&logical.TokenEntry{ + a.HandleTokenUsage(&logical.TokenEntry{ Path: "test", Policies: []string{responseWrappingPolicyName}, CreationTime: time.Now().Unix(), @@ -126,7 +128,7 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { } a.fragmentLock.Unlock() - a.HandleTokenCreation(&logical.TokenEntry{ + a.HandleTokenUsage(&logical.TokenEntry{ Path: "test", Policies: []string{controlGroupPolicyName}, CreationTime: time.Now().Unix(), @@ -144,12 +146,12 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { func checkExpectedEntitiesInMap(t *testing.T, a *ActivityLog, entityIDs []string) { t.Helper() - activeEntities := a.core.GetActiveEntities() - if len(activeEntities) != len(entityIDs) { - t.Fatalf("mismatched number of entities, expected %v got %v", len(entityIDs), activeEntities) + activeClients := a.core.GetActiveClients() + if len(activeClients) != len(entityIDs) { + t.Fatalf("mismatched number of entities, expected %v got %v", len(entityIDs), activeClients) } for _, e := range entityIDs { - if _, present := activeEntities[e]; !present { + if _, present := activeClients[e]; !present { t.Errorf("entity ID %q is missing", e) } } @@ -176,11 +178,11 @@ func TestActivityLog_UniqueEntities(t *testing.T) { t.Fatal("no current fragment") } - if len(a.fragment.Entities) != 2 { - t.Fatalf("number of entities is %v", len(a.fragment.Entities)) + if len(a.fragment.Clients) != 2 { + t.Fatalf("number of entities is %v", len(a.fragment.Clients)) } - for i, e := range a.fragment.Entities { + for i, e := range a.fragment.Clients { expectedID := id1 expectedTime := t1.Unix() expectedNS := "root" @@ -188,8 +190,8 @@ func TestActivityLog_UniqueEntities(t *testing.T) { expectedID = id2 expectedTime = t2.Unix() } - if e.EntityID != expectedID { - t.Errorf("%v: expected %q, got %q", i, expectedID, e.EntityID) + if e.ClientID != expectedID { + t.Errorf("%v: expected %q, got %q", i, expectedID, e.ClientID) } if e.NamespaceID != expectedNS { t.Errorf("%v: expected %q, got %q", i, expectedNS, e.NamespaceID) @@ -202,6 +204,17 @@ func TestActivityLog_UniqueEntities(t *testing.T) { checkExpectedEntitiesInMap(t, a, []string{id1, id2}) } +func readSegmentFromStorageNil(t *testing.T, c *Core, path string) { + t.Helper() + logSegment, err := c.barrier.Get(context.Background(), path) + if err != nil { + t.Fatal(err) + } + if logSegment != nil { + t.Fatalf("expected non-nil log segment at %q", path) + } +} + func readSegmentFromStorage(t *testing.T, c *Core, path string) *logical.StorageEntry { t.Helper() logSegment, err := c.barrier.Get(context.Background(), path) @@ -229,15 +242,15 @@ func expectMissingSegment(t *testing.T, c *Core, path string) { func expectedEntityIDs(t *testing.T, out *activity.EntityActivityLog, ids []string) { t.Helper() - if len(out.Entities) != len(ids) { - t.Fatalf("entity log expected length %v, actual %v", len(ids), len(out.Entities)) + if len(out.Clients) != len(ids) { + t.Fatalf("entity log expected length %v, actual %v", len(ids), len(out.Clients)) } // Double 
loop, OK for small cases for _, id := range ids { found := false - for _, e := range out.Entities { - if e.EntityID == id { + for _, e := range out.Clients { + if e.ClientID == id { found = true break } @@ -271,8 +284,8 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { t.Errorf("fragment was not reset after write to storage") } - protoSegment := readSegmentFromStorage(t, core, path) out := &activity.TokenCount{} + protoSegment := readSegmentFromStorage(t, core, path) err = proto.Unmarshal(protoSegment.Value, out) if err != nil { t.Fatalf("could not unmarshal protobuf: %v", err) @@ -329,6 +342,75 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { } } +// TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount ensures that +// a new fragment with nonEntityTokens will not update the currentSegment's +// tokenCount, as this field will not be used going forward. +func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + ctx := context.Background() + + a := core.activityLog + a.SetStandbyEnable(ctx, true) + a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment + + tokenPath := fmt.Sprintf("%sdirecttokens/%d/0", ActivityLogPrefix, a.GetStartTimestamp()) + clientPath := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", a.GetStartTimestamp()) + // Create some entries without entityIDs + tokenEntryOne := logical.TokenEntry{NamespaceID: "ns1_id", Policies: []string{"hi"}} + entityEntry := logical.TokenEntry{EntityID: "foo", NamespaceID: "ns1_id", Policies: []string{"hi"}} + + id, _ := a.CreateClientID(&tokenEntryOne) + + for i := 0; i < 3; i++ { + a.HandleTokenUsage(&tokenEntryOne) + } + for i := 0; i < 2; i++ { + a.HandleTokenUsage(&entityEntry) + } + err := a.saveCurrentSegmentToStorage(ctx, false) + if err != nil { + t.Fatalf("got error writing TWEs to storage: %v", err) + } + + // Assert that new elements have been written to the fragment + if a.fragment != nil { + t.Errorf("fragment was not reset after write to storage") + } + + // Assert that no tokens have been written to the fragment + readSegmentFromStorageNil(t, core, tokenPath) + + e := readSegmentFromStorage(t, core, clientPath) + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(e.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + if len(out.Clients) != 2 { + t.Fatalf("added 3 distinct TWEs and 2 distinct entity tokens that should all result in the same ID, got: %d", len(out.Clients)) + } + nonEntityTokenFlag := false + entityTokenFlag := false + for _, client := range out.Clients { + if client.NonEntity == true { + nonEntityTokenFlag = true + if client.ClientID != id { + t.Fatalf("expected a client ID of %s, but saved instead %s", id, client.ClientID) + } + } + if client.NonEntity == false { + entityTokenFlag = true + if client.ClientID != "foo" { + t.Fatalf("expected a client ID of %s, but saved instead %s", "foo", client.ClientID) + } + } + } + + if !nonEntityTokenFlag || !entityTokenFlag { + t.Fatalf("Saved clients missing TWE: %v; saved clients missing entity token: %v", nonEntityTokenFlag, entityTokenFlag) + } +} + func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { core, _, _ := TestCoreUnsealed(t) ctx := context.Background() @@ -392,12 +474,12 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { entityRecords := []*activity.EntityRecord{ { - EntityID: ids[0], + ClientID: ids[0], NamespaceID: "root", Timestamp: time.Now().Unix(), }, { - EntityID: ids[1], + ClientID: ids[1], 
NamespaceID: "root", Timestamp: time.Now().Unix(), }, @@ -405,7 +487,7 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { fragment := &activity.LogFragment{ OriginatingNode: "test-123", - Entities: entityRecords, + Clients: entityRecords, NonEntityTokens: make(map[string]uint64), } @@ -475,7 +557,7 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog - // enabled check is now inside AddEntityToFragment + // enabled check is now inside AddClientToFragment a.SetEnable(true) a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment @@ -522,8 +604,8 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { if err != nil { t.Fatalf("could not unmarshal protobuf: %v", err) } - if len(entityLog0.Entities) != 7000 { - t.Fatalf("unexpected entity length. Expected %d, got %d", 7000, len(entityLog0.Entities)) + if len(entityLog0.Clients) != 7000 { + t.Fatalf("unexpected entity length. Expected %d, got %d", 7000, len(entityLog0.Clients)) } // 7000 more local entities @@ -539,12 +621,12 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { } fragment1 := &activity.LogFragment{ OriginatingNode: "test-123", - Entities: make([]*activity.EntityRecord, 0, 100), + Clients: make([]*activity.EntityRecord, 0, 100), NonEntityTokens: tokens1, } for i := 7000; i < 7100; i++ { - fragment1.Entities = append(fragment1.Entities, &activity.EntityRecord{ - EntityID: genID(i), + fragment1.Clients = append(fragment1.Clients, &activity.EntityRecord{ + ClientID: genID(i), NamespaceID: "root", Timestamp: ts, }) @@ -558,12 +640,12 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { } fragment2 := &activity.LogFragment{ OriginatingNode: "test-123", - Entities: make([]*activity.EntityRecord, 0, 100), + Clients: make([]*activity.EntityRecord, 0, 100), NonEntityTokens: tokens2, } for i := 14000; i < 14100; i++ { - fragment2.Entities = append(fragment2.Entities, &activity.EntityRecord{ - EntityID: genID(i), + fragment2.Clients = append(fragment2.Clients, &activity.EntityRecord{ + ClientID: genID(i), NamespaceID: "root", Timestamp: ts, }) @@ -592,9 +674,9 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { if err != nil { t.Fatalf("could not unmarshal protobuf: %v", err) } - if len(entityLog0.Entities) != activitySegmentEntityCapacity { + if len(entityLog0.Clients) != activitySegmentEntityCapacity { t.Fatalf("unexpected entity length. Expected %d, got %d", activitySegmentEntityCapacity, - len(entityLog0.Entities)) + len(entityLog0.Clients)) } protoSegment1 := readSegmentFromStorage(t, core, path1) @@ -604,17 +686,17 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { t.Fatalf("could not unmarshal protobuf: %v", err) } expectedCount := 14100 - activitySegmentEntityCapacity - if len(entityLog1.Entities) != expectedCount { + if len(entityLog1.Clients) != expectedCount { t.Fatalf("unexpected entity length. 
Expected %d, got %d", expectedCount, - len(entityLog1.Entities)) + len(entityLog1.Clients)) } entityPresent := make(map[string]struct{}) - for _, e := range entityLog0.Entities { - entityPresent[e.EntityID] = struct{}{} + for _, e := range entityLog0.Clients { + entityPresent[e.ClientID] = struct{}{} } - for _, e := range entityLog1.Entities { - entityPresent[e.EntityID] = struct{}{} + for _, e := range entityLog1.Clients { + entityPresent[e.ClientID] = struct{}{} } for i := 0; i < 14100; i++ { expectedID := genID(i) @@ -622,7 +704,6 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { t.Fatalf("entity ID %v = %v not present", i, expectedID) } } - expectedNSCounts := map[string]uint64{ "root": 9, "aaaaa": 11, @@ -958,7 +1039,7 @@ func entityRecordsEqual(t *testing.T, record1, record2 []*activity.EntityRecord) return false } - idComp := strings.Compare(ei.EntityID, ej.EntityID) + idComp := strings.Compare(ei.ClientID, ej.ClientID) if idComp == -1 { return true } @@ -983,7 +1064,7 @@ func entityRecordsEqual(t *testing.T, record1, record2 []*activity.EntityRecord) for i, a := range entitiesCopy1 { b := entitiesCopy2[i] - if a.EntityID != b.EntityID || a.NamespaceID != b.NamespaceID || a.Timestamp != b.Timestamp { + if a.ClientID != b.ClientID || a.NamespaceID != b.NamespaceID || a.Timestamp != b.Timestamp { return false } } @@ -1000,49 +1081,50 @@ func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { defer a.fragmentLock.Unlock() a.currentSegment = segmentInfo{ startTimestamp: time.Time{}.Unix(), - currentEntities: &activity.EntityActivityLog{ - Entities: make([]*activity.EntityRecord, 0), + currentClients: &activity.EntityActivityLog{ + Clients: make([]*activity.EntityRecord, 0), }, tokenCount: a.currentSegment.tokenCount, - entitySequenceNumber: 0, + clientSequenceNumber: 0, } - a.entityTracker.activeEntities = make(map[string]struct{}) + a.clientTracker.activeClients = make(map[string]struct{}) } -func TestActivityLog_loadCurrentEntitySegment(t *testing.T) { +func TestActivityLog_loadCurrentClientSegment(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog - - // we must verify that loadCurrentEntitySegment doesn't overwrite the in-memory token count + // we must verify that loadCurrentClientSegment doesn't overwrite the in-memory token count tokenRecords := make(map[string]uint64) tokenRecords["test"] = 1 tokenCount := &activity.TokenCount{ CountByNamespaceID: tokenRecords, } - a.SetTokenCount(tokenCount) + a.l.Lock() + a.currentSegment.tokenCount = tokenCount + a.l.Unlock() // setup in-storage data to load for testing entityRecords := []*activity.EntityRecord{ { - EntityID: "11111111-1111-1111-1111-111111111111", + ClientID: "11111111-1111-1111-1111-111111111111", NamespaceID: "root", Timestamp: time.Now().Unix(), }, { - EntityID: "22222222-2222-2222-2222-222222222222", + ClientID: "22222222-2222-2222-2222-222222222222", NamespaceID: "root", Timestamp: time.Now().Unix(), }, } testEntities1 := &activity.EntityActivityLog{ - Entities: entityRecords[:1], + Clients: entityRecords[:1], } testEntities2 := &activity.EntityActivityLog{ - Entities: entityRecords[1:2], + Clients: entityRecords[1:2], } testEntities3 := &activity.EntityActivityLog{ - Entities: entityRecords[:2], + Clients: entityRecords[:2], } time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix() @@ -1084,11 +1166,19 @@ func TestActivityLog_loadCurrentEntitySegment(t *testing.T) { ctx := context.Background() for _, tc := range testCases { - err := a.loadCurrentEntitySegment(ctx, 
time.Unix(tc.time, 0), tc.seqNum) + a.l.Lock() + a.fragmentLock.Lock() + // loadCurrentClientSegment requires us to grab the fragment lock and the + // activityLog lock, as per the comment in the loadCurrentClientSegment + // function + err := a.loadCurrentClientSegment(ctx, time.Unix(tc.time, 0), tc.seqNum) + a.fragmentLock.Unlock() + a.l.Unlock() + if err != nil { t.Fatalf("got error loading data for %q: %v", tc.path, err) } - if !reflect.DeepEqual(a.GetCountByNamespaceID(), tokenCount.CountByNamespaceID) { + if !reflect.DeepEqual(a.GetStoredTokenCountByNamespaceID(), tokenCount.CountByNamespaceID) { t.Errorf("this function should not wipe out the in-memory token count") } @@ -1104,13 +1194,13 @@ func TestActivityLog_loadCurrentEntitySegment(t *testing.T) { } currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Entities, tc.entities.Entities) { - t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Entities, currentEntities, tc.path) + if !entityRecordsEqual(t, currentEntities.Clients, tc.entities.Clients) { + t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Clients, currentEntities, tc.path) } - activeEntities := core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, tc.entities.Entities) { - t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, activeEntities, tc.path) + activeClients := core.GetActiveClients() + if !ActiveEntitiesEqual(activeClients, tc.entities.Clients) { + t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v for path %q", tc.entities.Clients, activeClients, tc.path) } a.resetEntitiesInMemory(t) @@ -1125,21 +1215,21 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { // setup in-storage data to load for testing entityRecords := []*activity.EntityRecord{ { - EntityID: "11111111-1111-1111-1111-111111111111", + ClientID: "11111111-1111-1111-1111-111111111111", NamespaceID: "root", Timestamp: time.Now().Unix(), }, { - EntityID: "22222222-2222-2222-2222-222222222222", + ClientID: "22222222-2222-2222-2222-222222222222", NamespaceID: "root", Timestamp: time.Now().Unix(), }, } testEntities1 := &activity.EntityActivityLog{ - Entities: entityRecords[:1], + Clients: entityRecords[:1], } testEntities2 := &activity.EntityActivityLog{ - Entities: entityRecords[:2], + Clients: entityRecords[:2], } time1 := time.Date(2020, 4, 1, 0, 0, 0, 0, time.UTC).Unix() @@ -1190,7 +1280,7 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { if tc.refresh { a.l.Lock() a.fragmentLock.Lock() - a.entityTracker.activeEntities = make(map[string]struct{}) + a.clientTracker.activeClients = make(map[string]struct{}) a.currentSegment.startTimestamp = tc.time a.fragmentLock.Unlock() a.l.Unlock() @@ -1201,13 +1291,15 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { t.Fatalf("got error loading data for %q: %v", tc.path, err) } - activeEntities := core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, tc.entities.Entities) { - t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v for path %q", tc.entities.Entities, activeEntities, tc.path) + activeClients := core.GetActiveClients() + if !ActiveEntitiesEqual(activeClients, tc.entities.Clients) { + t.Errorf("bad data loaded into active entities. 
expected only set of EntityID from %v in %v for path %q", tc.entities.Clients, activeClients, tc.path) } } } +// TestActivityLog_loadTokenCount ensures that previous segments with tokenCounts +// can still be read from storage, even when TWE's have distinct, tracked clientIDs. func TestActivityLog_loadTokenCount(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog @@ -1241,19 +1333,18 @@ func TestActivityLog_loadTokenCount(t *testing.T) { }, } + ctx := context.Background() for _, tc := range testCases { WriteToStorage(t, core, ActivityLogPrefix+tc.path, data) } - ctx := context.Background() for _, tc := range testCases { - // a.currentSegment.tokenCount doesn't need to be wiped each iter since it happens in loadTokenSegment() err := a.loadTokenCount(ctx, time.Unix(tc.time, 0)) if err != nil { t.Fatalf("got error loading data for %q: %v", tc.path, err) } - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if !reflect.DeepEqual(nsCount, tokenRecords) { t.Errorf("bad token count loaded. expected: %v got: %v for path %q", tokenRecords, nsCount, tc.path) } @@ -1300,7 +1391,7 @@ func TestActivityLog_StopAndRestart(t *testing.T) { wg.Wait() a = core.activityLog - if a.GetCountByNamespaceID() == nil { + if a.GetStoredTokenCountByNamespaceID() == nil { t.Fatalf("nil token count map") } @@ -1327,17 +1418,17 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities if includeEntities { entityRecords = []*activity.EntityRecord{ { - EntityID: "11111111-1111-1111-1111-111111111111", + ClientID: "11111111-1111-1111-1111-111111111111", NamespaceID: namespace.RootNamespaceID, Timestamp: time.Now().Unix(), }, { - EntityID: "22222222-2222-2222-2222-222222222222", + ClientID: "22222222-2222-2222-2222-222222222222", NamespaceID: namespace.RootNamespaceID, Timestamp: time.Now().Unix(), }, { - EntityID: "33333333-2222-2222-2222-222222222222", + ClientID: "33333333-2222-2222-2222-222222222222", NamespaceID: namespace.RootNamespaceID, Timestamp: time.Now().Unix(), }, @@ -1345,7 +1436,7 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities if constants.IsEnterprise { entityRecords = append(entityRecords, []*activity.EntityRecord{ { - EntityID: "44444444-1111-1111-1111-111111111111", + ClientID: "44444444-1111-1111-1111-111111111111", NamespaceID: "ns1", Timestamp: time.Now().Unix(), }, @@ -1353,7 +1444,7 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities } for i, entityRecord := range entityRecords { entityData, err := proto.Marshal(&activity.EntityActivityLog{ - Entities: []*activity.EntityRecord{entityRecord}, + Clients: []*activity.EntityRecord{entityRecord}, }) if err != nil { t.Fatalf(err.Error()) @@ -1393,7 +1484,7 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities } func TestActivityLog_refreshFromStoredLog(t *testing.T) { - a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true) + a, expectedClientRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true) a.SetEnable(true) var wg sync.WaitGroup @@ -1404,33 +1495,92 @@ func TestActivityLog_refreshFromStoredLog(t *testing.T) { wg.Wait() expectedActive := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[1:], + Clients: expectedClientRecords[1:], } expectedCurrent := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[len(expectedEntityRecords)-1:], + Clients: 
expectedClientRecords[len(expectedClientRecords)-1:], } currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) { + if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we only expect the newest entity segment to be loaded (for the current month) t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) } - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if !reflect.DeepEqual(nsCount, expectedTokenCounts) { // we expect all token counts to be loaded t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } - activeEntities := a.core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) { - // we expect activeEntities to be loaded for the entire month - t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities) + activeClients := a.core.GetActiveClients() + if !ActiveEntitiesEqual(activeClients, expectedActive.Clients) { + // we expect activeClients to be loaded for the entire month + t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Clients, activeClients) + } +} + +// TestCreateClientID verifies that CreateClientID uses the entity ID for a token +// entry if one exists, and creates an appropriate client ID otherwise. +func TestCreateClientID(t *testing.T) { + entry := logical.TokenEntry{NamespaceID: "namespaceFoo", Policies: []string{"bar", "baz", "foo", "banana"}} + activityLog := ActivityLog{} + id, isTWE := activityLog.CreateClientID(&entry) + if !isTWE { + t.Fatalf("TWE token should return true value in isTWE bool") + } + expectedIDPlaintext := "banana" + string(sortedPoliciesTWEDelimiter) + "bar" + + string(sortedPoliciesTWEDelimiter) + "baz" + + string(sortedPoliciesTWEDelimiter) + "foo" + string(clientIDTWEDelimiter) + "namespaceFoo" + + hashed := sha256.Sum256([]byte(expectedIDPlaintext)) + expectedID := base64.URLEncoding.EncodeToString(hashed[:]) + if expectedID != id { + t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) + } + // Test with entityID + entry = logical.TokenEntry{EntityID: "entityFoo", NamespaceID: "namespaceFoo", Policies: []string{"bar", "baz", "foo", "banana"}} + id, isTWE = activityLog.CreateClientID(&entry) + if isTWE { + t.Fatalf("token with entity should return false value in isTWE bool") + } + if id != "entityFoo" { + t.Fatalf("client ID should be entity ID") + } + + // Test without namespace + entry = logical.TokenEntry{Policies: []string{"bar", "baz", "foo", "banana"}} + id, isTWE = activityLog.CreateClientID(&entry) + if !isTWE { + t.Fatalf("TWE token should return true value in isTWE bool") + } + expectedIDPlaintext = "banana" + string(sortedPoliciesTWEDelimiter) + "bar" + + string(sortedPoliciesTWEDelimiter) + "baz" + + string(sortedPoliciesTWEDelimiter) + "foo" + string(clientIDTWEDelimiter) + + hashed = sha256.Sum256([]byte(expectedIDPlaintext)) + expectedID = base64.URLEncoding.EncodeToString(hashed[:]) + if expectedID != id { + t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) + } + + // Test without policies + entry = logical.TokenEntry{NamespaceID: "namespaceFoo"} + id, isTWE = activityLog.CreateClientID(&entry) + if !isTWE { + t.Fatalf("TWE token should return true value in isTWE bool") + } + expectedIDPlaintext = "namespaceFoo" + + hashed = 
sha256.Sum256([]byte(expectedIDPlaintext)) + expectedID = base64.URLEncoding.EncodeToString(hashed[:]) + if expectedID != id { + t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) } } func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testing.T) { - a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true) + a, expectedClientRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, time.Now().UTC(), true, true) a.SetEnable(true) var wg sync.WaitGroup @@ -1448,25 +1598,25 @@ func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testi wg.Wait() expected := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[len(expectedEntityRecords)-1:], + Clients: expectedClientRecords[len(expectedClientRecords)-1:], } currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Entities, expected.Entities) { + if !entityRecordsEqual(t, currentEntities.Clients, expected.Clients) { // we only expect the newest entity segment to be loaded (for the current month) t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expected, currentEntities) } - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if !reflect.DeepEqual(nsCount, expectedTokenCounts) { // we expect all token counts to be loaded t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } - activeEntities := a.core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, expected.Entities) { - // we only expect activeEntities to be loaded for the newest segment (for the current month) - t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expected.Entities, activeEntities) + activeClients := a.core.GetActiveClients() + if !ActiveEntitiesEqual(activeClients, expected.Clients) { + // we only expect activeClients to be loaded for the newest segment (for the current month) + t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expected.Clients, activeClients) } } @@ -1484,7 +1634,7 @@ func TestActivityLog_refreshFromStoredLogContextCancelled(t *testing.T) { } func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { - a, expectedEntityRecords, _ := setupActivityRecordsInStorage(t, time.Now().UTC(), true, false) + a, expectedClientRecords, _ := setupActivityRecordsInStorage(t, time.Now().UTC(), true, false) a.SetEnable(true) var wg sync.WaitGroup @@ -1495,24 +1645,24 @@ func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { wg.Wait() expectedActive := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[1:], + Clients: expectedClientRecords[1:], } expectedCurrent := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[len(expectedEntityRecords)-1:], + Clients: expectedClientRecords[len(expectedClientRecords)-1:], } currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) { + if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we expect all segments for the current month to be loaded t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) } - activeEntities := a.core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) { - t.Errorf("bad data loaded into active entities. 
expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities) + activeClients := a.core.GetActiveClients() + if !ActiveEntitiesEqual(activeClients, expectedActive.Clients) { + t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Clients, activeClients) } // we expect no tokens - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if len(nsCount) > 0 { t.Errorf("expected no token counts to be loaded. got: %v", nsCount) } @@ -1529,19 +1679,19 @@ func TestActivityLog_refreshFromStoredLogNoEntities(t *testing.T) { } wg.Wait() - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if !reflect.DeepEqual(nsCount, expectedTokenCounts) { // we expect all token counts to be loaded t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } currentEntities := a.GetCurrentEntities() - if len(currentEntities.Entities) > 0 { + if len(currentEntities.Clients) > 0 { t.Errorf("expected no current entity segment to be loaded. got: %v", currentEntities) } - activeEntities := a.core.GetActiveEntities() - if len(activeEntities) > 0 { - t.Errorf("expected no active entity segment to be loaded. got: %v", activeEntities) + activeClients := a.core.GetActiveClients() + if len(activeClients) > 0 { + t.Errorf("expected no active entity segment to be loaded. got: %v", activeClients) } } @@ -1583,7 +1733,7 @@ func TestActivityLog_refreshFromStoredLogPreviousMonth(t *testing.T) { // can handle end of month rotations monthStart := timeutil.StartOfMonth(time.Now().UTC()) oneMonthAgoStart := timeutil.StartOfPreviousMonth(monthStart) - a, expectedEntityRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, oneMonthAgoStart, true, true) + a, expectedClientRecords, expectedTokenCounts := setupActivityRecordsInStorage(t, oneMonthAgoStart, true, true) a.SetEnable(true) var wg sync.WaitGroup @@ -1594,28 +1744,28 @@ func TestActivityLog_refreshFromStoredLogPreviousMonth(t *testing.T) { wg.Wait() expectedActive := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[1:], + Clients: expectedClientRecords[1:], } expectedCurrent := &activity.EntityActivityLog{ - Entities: expectedEntityRecords[len(expectedEntityRecords)-1:], + Clients: expectedClientRecords[len(expectedClientRecords)-1:], } currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Entities, expectedCurrent.Entities) { + if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we only expect the newest entity segment to be loaded (for the current month) t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) } - nsCount := a.GetCountByNamespaceID() + nsCount := a.GetStoredTokenCountByNamespaceID() if !reflect.DeepEqual(nsCount, expectedTokenCounts) { // we expect all token counts to be loaded t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } - activeEntities := a.core.GetActiveEntities() - if !ActiveEntitiesEqual(activeEntities, expectedActive.Entities) { - // we expect activeEntities to be loaded for the entire month - t.Errorf("bad data loaded into active entities. 
expected only set of EntityID from %v in %v", expectedActive.Entities, activeEntities)
+	activeClients := a.core.GetActiveClients()
+	if !ActiveEntitiesEqual(activeClients, expectedActive.Clients) {
+		// we expect activeClients to be loaded for the entire month
+		t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v", expectedActive.Clients, activeClients)
 	}
 }
 
@@ -1911,6 +2061,404 @@ func TestActivityLog_EndOfMonth(t *testing.T) {
 	}
 }
 
+// TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs tests that precomputed
+// queries are correct when a month has tokens without entities stored in the
+// TokenCount, stored as clients, or both.
+func TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs(t *testing.T) {
+	timeutil.SkipAtEndOfMonth(t)
+
+	// The root namespace will have both TWEs with clientIDs and untracked TWEs (in TokenCount).
+	// The ns1 namespace will only have untracked TWEs.
+	// The aaaaa, bbbbb, and ccccc namespaces will only have TWEs with clientIDs.
+	// 1. January tests that clientIDs from one month's segment don't roll over into
+	//    another month's client counts for the same segment number.
+	// 2. August tests that client counts are correct when a month is split across segments.
+	// 3. September tests that a month made up of entire segments yields the correct client counts.
+	// 4. October tests that a month whose only segment rolled over from the previous
+	//    month yields the correct client count.
+
+	january := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
+	august := time.Date(2020, 8, 15, 12, 0, 0, 0, time.UTC)
+	september := timeutil.StartOfMonth(time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC))
+	october := timeutil.StartOfMonth(time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC))
+	november := timeutil.StartOfMonth(time.Date(2020, 11, 1, 0, 0, 0, 0, time.UTC))
+
+	core, _, _, sink := TestCoreUnsealedWithMetrics(t)
+	a := core.activityLog
+	ctx := namespace.RootContext(nil)
+
+	// Generate overlapping sets of entity IDs from this list.
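+	// january: 40-44 RRRRR
+	// august (segments 0 and 1): 0-19 RRRRRAAAAABBBBBRRRRR
+	// september (segments 1, 2, and 3): 10-39 BBBBBRRRRRRRRRRCCCCCRRRRRBBBBB
+	// october (segment 3 only): 17-22 RRRRRR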
+ + clientRecords := make([]*activity.EntityRecord, 45) + clientNamespaces := []string{"root", "aaaaa", "bbbbb", "root", "root", "ccccc", "root", "bbbbb", "rrrrr"} + + for i := range clientRecords { + clientRecords[i] = &activity.EntityRecord{ + ClientID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i), + NamespaceID: clientNamespaces[i/5], + Timestamp: time.Now().Unix(), + NonEntity: true, + } + } + + toInsert := []struct { + StartTime int64 + Segment uint64 + Clients []*activity.EntityRecord + }{ + // January, should not be included + { + january.Unix(), + 0, + clientRecords[40:45], + }, + { + august.Unix(), + 0, + clientRecords[:13], + }, + { + august.Unix(), + 1, + clientRecords[13:20], + }, + { + september.Unix(), + 1, + clientRecords[10:30], + }, + { + september.Unix(), + 2, + clientRecords[15:40], + }, + { + september.Unix(), + 3, + clientRecords[15:40], + }, + { + october.Unix(), + 3, + clientRecords[17:23], + }, + } + + // Insert token counts for all 3 segments + toInsertTokenCount := []struct { + StartTime int64 + Segment uint64 + CountByNamespaceID map[string]uint64 + }{ + { + january.Unix(), + 0, + map[string]uint64{"root": 3, "ns1": 5}, + }, + { + august.Unix(), + 0, + map[string]uint64{"root": 40, "ns1": 50}, + }, + { + august.Unix(), + 1, + map[string]uint64{"root": 60, "ns1": 70}, + }, + { + september.Unix(), + 1, + map[string]uint64{"root": 400, "ns1": 500}, + }, + { + september.Unix(), + 2, + map[string]uint64{"root": 700, "ns1": 800}, + }, + { + september.Unix(), + 3, + map[string]uint64{"root": 0, "ns1": 0}, + }, + { + october.Unix(), + 3, + map[string]uint64{"root": 0, "ns1": 0}, + }, + } + doInsertTokens := func(i int) { + segment := toInsertTokenCount[i] + tc := &activity.TokenCount{ + CountByNamespaceID: segment.CountByNamespaceID, + } + data, err := proto.Marshal(tc) + if err != nil { + t.Fatal(err) + } + tokenPath := fmt.Sprintf("%vdirecttokens/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + WriteToStorage(t, core, tokenPath, data) + } + + // Note that precomputedQuery worker doesn't filter + // for times <= the one it was asked to do. Is that a problem? + // Here, it means that we can't insert everything *first* and do multiple + // test cases, we have to write logs incrementally. 
+ doInsert := func(i int) { + segment := toInsert[i] + eal := &activity.EntityActivityLog{ + Clients: segment.Clients, + } + data, err := proto.Marshal(eal) + if err != nil { + t.Fatal(err) + } + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + WriteToStorage(t, core, path, data) + } + expectedCounts := []struct { + StartTime time.Time + EndTime time.Time + ByNamespace map[string]int + }{ + // First test case + { + august, + timeutil.EndOfMonth(august), + map[string]int{ + "root": 110, // 10 TWEs + 50 + 60 direct tokens + "ns1": 120, // 60 + 70 direct tokens + "aaaaa": 5, + "bbbbb": 5, + }, + }, + // Second test case + { + august, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 1220, // 110 from august + 10 non-overlapping TWEs in September, + 400 + 700 direct tokens in september + "ns1": 1420, // 120 from August + 500 + 800 direct tokens in september + "aaaaa": 5, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 1115, // 15 TWEs in September, + 400 + 700 direct tokens + "ns1": 1300, // 500 direct tokens in september + "bbbbb": 10, + "ccccc": 5, + }, + }, + // Third test case + { + august, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 1220, // 1220 from Aug to Sept + "ns1": 1420, // 1420 from Aug to Sept + "aaaaa": 5, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 1115, // 1115 in Sept + "ns1": 1300, // 1300 in Sept + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + october, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 6, // 6 overlapping TWEs in October + "ns1": 0, // No new direct tokens in october + }, + }, + } + + checkPrecomputedQuery := func(i int) { + t.Helper() + pq, err := a.queryStore.Get(ctx, expectedCounts[i].StartTime, expectedCounts[i].EndTime) + if err != nil { + t.Fatal(err) + } + if pq == nil { + t.Errorf("empty result for %v -- %v", expectedCounts[i].StartTime, expectedCounts[i].EndTime) + } + if len(pq.Namespaces) != len(expectedCounts[i].ByNamespace) { + t.Errorf("mismatched number of namespaces, expected %v got %v", + len(expectedCounts[i].ByNamespace), len(pq.Namespaces)) + } + for _, nsRecord := range pq.Namespaces { + val, ok := expectedCounts[i].ByNamespace[nsRecord.NamespaceID] + if !ok { + t.Errorf("unexpected namespace %v", nsRecord.NamespaceID) + continue + } + if uint64(val) != nsRecord.NonEntityTokens { + t.Errorf("wrong number of entities in %v: expected %v, got %v", + nsRecord.NamespaceID, val, nsRecord.NonEntityTokens) + } + } + if !pq.StartTime.Equal(expectedCounts[i].StartTime) { + t.Errorf("mismatched start time: expected %v got %v", + expectedCounts[i].StartTime, pq.StartTime) + } + if !pq.EndTime.Equal(expectedCounts[i].EndTime) { + t.Errorf("mismatched end time: expected %v got %v", + expectedCounts[i].EndTime, pq.EndTime) + } + } + + testCases := []struct { + InsertUpTo int // index in the toInsert array + PrevMonth int64 + NextMonth int64 + ExpectedUpTo int // index in the expectedCounts array + }{ + { + 2, // jan-august + august.Unix(), + september.Unix(), + 0, // august-august + }, + { + 5, // jan-sept + september.Unix(), + october.Unix(), + 2, // august-september + }, + { + 6, // jan-oct + october.Unix(), + november.Unix(), + 5, // august-october + }, + } + + inserted := -1 + for _, tc := range testCases { + t.Logf("tc %+v", tc) + + // Persists across loops + for inserted < tc.InsertUpTo { + inserted += 1 + t.Logf("inserting 
segment %v", inserted) + doInsert(inserted) + doInsertTokens(inserted) + } + + intent := &ActivityIntentLog{ + PreviousMonth: tc.PrevMonth, + NextMonth: tc.NextMonth, + } + data, err := json.Marshal(intent) + if err != nil { + t.Fatal(err) + } + WriteToStorage(t, core, "sys/counters/activity/endofmonth", data) + + // Pretend we've successfully rolled over to the following month + a.SetStartTimestamp(tc.NextMonth) + + err = a.precomputedQueryWorker() + if err != nil { + t.Fatal(err) + } + + expectMissingSegment(t, core, "sys/counters/activity/endofmonth") + + for i := 0; i <= tc.ExpectedUpTo; i++ { + checkPrecomputedQuery(i) + } + } + + // Check metrics on the last precomputed query + // (otherwise we need a way to reset the in-memory metrics between test cases.) + + intervals := sink.Data() + // Test crossed an interval boundary, don't try to deal with it. + if len(intervals) > 1 { + t.Skip("Detected interval crossing.") + } + expectedGauges := []struct { + Name string + NamespaceLabel string + Value float32 + }{ + // october values + { + "identity.nonentity.active.monthly", + "root", + 6.0, + }, + { + "identity.nonentity.active.monthly", + "deleted-bbbbb", // No namespace entry for this fake ID + 10.0, + }, + { + "identity.nonentity.active.monthly", + "deleted-ccccc", + 5.0, + }, + // august-september values + { + "identity.nonentity.active.reporting_period", + "root", + 1220.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-aaaaa", + 5.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-bbbbb", + 10.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-ccccc", + 5.0, + }, + } + for _, g := range expectedGauges { + found := false + for _, actual := range intervals[0].Gauges { + actualNamespaceLabel := "" + for _, l := range actual.Labels { + if l.Name == "namespace" { + actualNamespaceLabel = l.Value + break + } + } + if actual.Name == g.Name && actualNamespaceLabel == g.NamespaceLabel { + found = true + if actual.Value != g.Value { + t.Errorf("Mismatched value for %v %v %v != %v", + g.Name, g.NamespaceLabel, actual.Value, g.Value) + } + break + } + } + if !found { + t.Errorf("No guage found for %v %v", + g.Name, g.NamespaceLabel) + } + } +} + func TestActivityLog_SaveAfterDisable(t *testing.T) { core, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) @@ -1976,7 +2524,7 @@ func TestActivityLog_Precompute(t *testing.T) { for i := range entityRecords { entityRecords[i] = &activity.EntityRecord{ - EntityID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i), + ClientID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i), NamespaceID: entityNamespaces[i/5], Timestamp: time.Now().Unix(), } @@ -1985,7 +2533,7 @@ func TestActivityLog_Precompute(t *testing.T) { toInsert := []struct { StartTime int64 Segment uint64 - Entities []*activity.EntityRecord + Clients []*activity.EntityRecord }{ // January, should not be included { @@ -2033,7 +2581,7 @@ func TestActivityLog_Precompute(t *testing.T) { doInsert := func(i int) { segment := toInsert[i] eal := &activity.EntityActivityLog{ - Entities: segment.Entities, + Clients: segment.Clients, } data, err := proto.Marshal(eal) if err != nil { @@ -2282,6 +2830,341 @@ func TestActivityLog_Precompute(t *testing.T) { } } +//TestActivityLog_PrecomputeNonEntityTokensWithID is the same test as +// TestActivityLog_Precompute, except all the clients are tokens without +// entities. This ensures the deduplication logic and separation logic between +// entities and TWE clients is correct. 
+func TestActivityLog_PrecomputeNonEntityTokensWithID(t *testing.T) { + timeutil.SkipAtEndOfMonth(t) + + january := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + august := time.Date(2020, 8, 15, 12, 0, 0, 0, time.UTC) + september := timeutil.StartOfMonth(time.Date(2020, 9, 1, 0, 0, 0, 0, time.UTC)) + october := timeutil.StartOfMonth(time.Date(2020, 10, 1, 0, 0, 0, 0, time.UTC)) + november := timeutil.StartOfMonth(time.Date(2020, 11, 1, 0, 0, 0, 0, time.UTC)) + + core, _, _, sink := TestCoreUnsealedWithMetrics(t) + a := core.activityLog + ctx := namespace.RootContext(nil) + + // Generate overlapping sets of entity IDs from this list. + // january: 40-44 RRRRR + // first month: 0-19 RRRRRAAAAABBBBBRRRRR + // second month: 10-29 BBBBBRRRRRRRRRRCCCCC + // third month: 15-39 RRRRRRRRRRCCCCCRRRRRBBBBB + + clientRecords := make([]*activity.EntityRecord, 45) + clientNamespaces := []string{"root", "aaaaa", "bbbbb", "root", "root", "ccccc", "root", "bbbbb", "rrrrr"} + + for i := range clientRecords { + clientRecords[i] = &activity.EntityRecord{ + ClientID: fmt.Sprintf("111122222-3333-4444-5555-%012v", i), + NamespaceID: clientNamespaces[i/5], + Timestamp: time.Now().Unix(), + NonEntity: true, + } + } + + toInsert := []struct { + StartTime int64 + Segment uint64 + Clients []*activity.EntityRecord + }{ + // January, should not be included + { + january.Unix(), + 0, + clientRecords[40:45], + }, + // Artifically split August and October + { // 1 + august.Unix(), + 0, + clientRecords[:13], + }, + { // 2 + august.Unix(), + 1, + clientRecords[13:20], + }, + { // 3 + september.Unix(), + 0, + clientRecords[10:30], + }, + { // 4 + october.Unix(), + 0, + clientRecords[15:40], + }, + { + october.Unix(), + 1, + clientRecords[15:40], + }, + { + october.Unix(), + 2, + clientRecords[17:23], + }, + } + + // Note that precomputedQuery worker doesn't filter + // for times <= the one it was asked to do. Is that a problem? + // Here, it means that we can't insert everything *first* and do multiple + // test cases, we have to write logs incrementally. 
+ doInsert := func(i int) { + segment := toInsert[i] + eal := &activity.EntityActivityLog{ + Clients: segment.Clients, + } + data, err := proto.Marshal(eal) + if err != nil { + t.Fatal(err) + } + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + WriteToStorage(t, core, path, data) + } + + expectedCounts := []struct { + StartTime time.Time + EndTime time.Time + ByNamespace map[string]int + }{ + // First test case + { + august, + timeutil.EndOfMonth(august), + map[string]int{ + "root": 10, + "aaaaa": 5, + "bbbbb": 5, + }, + }, + // Second test case + { + august, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 15, + "aaaaa": 5, + "bbbbb": 5, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(september), + map[string]int{ + "root": 10, + "bbbbb": 5, + "ccccc": 5, + }, + }, + // Third test case + { + august, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 20, + "aaaaa": 5, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + september, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 15, + "bbbbb": 10, + "ccccc": 5, + }, + }, + { + october, + timeutil.EndOfMonth(october), + map[string]int{ + "root": 15, + "bbbbb": 5, + "ccccc": 5, + }, + }, + } + + checkPrecomputedQuery := func(i int) { + t.Helper() + pq, err := a.queryStore.Get(ctx, expectedCounts[i].StartTime, expectedCounts[i].EndTime) + if err != nil { + t.Fatal(err) + } + if pq == nil { + t.Errorf("empty result for %v -- %v", expectedCounts[i].StartTime, expectedCounts[i].EndTime) + } + if len(pq.Namespaces) != len(expectedCounts[i].ByNamespace) { + t.Errorf("mismatched number of namespaces, expected %v got %v", + len(expectedCounts[i].ByNamespace), len(pq.Namespaces)) + } + for _, nsRecord := range pq.Namespaces { + val, ok := expectedCounts[i].ByNamespace[nsRecord.NamespaceID] + if !ok { + t.Errorf("unexpected namespace %v", nsRecord.NamespaceID) + continue + } + if uint64(val) != nsRecord.NonEntityTokens { + t.Errorf("wrong number of entities in %v: expected %v, got %v", + nsRecord.NamespaceID, val, nsRecord.NonEntityTokens) + } + } + if !pq.StartTime.Equal(expectedCounts[i].StartTime) { + t.Errorf("mismatched start time: expected %v got %v", + expectedCounts[i].StartTime, pq.StartTime) + } + if !pq.EndTime.Equal(expectedCounts[i].EndTime) { + t.Errorf("mismatched end time: expected %v got %v", + expectedCounts[i].EndTime, pq.EndTime) + } + } + + testCases := []struct { + InsertUpTo int // index in the toInsert array + PrevMonth int64 + NextMonth int64 + ExpectedUpTo int // index in the expectedCounts array + }{ + { + 2, // jan-august + august.Unix(), + september.Unix(), + 0, // august-august + }, + { + 3, // jan-sept + september.Unix(), + october.Unix(), + 2, // august-september + }, + { + 6, // jan-oct + october.Unix(), + november.Unix(), + 5, // august-september + }, + } + + inserted := -1 + for _, tc := range testCases { + t.Logf("tc %+v", tc) + + // Persists across loops + for inserted < tc.InsertUpTo { + inserted += 1 + t.Logf("inserting segment %v", inserted) + doInsert(inserted) + } + + intent := &ActivityIntentLog{ + PreviousMonth: tc.PrevMonth, + NextMonth: tc.NextMonth, + } + data, err := json.Marshal(intent) + if err != nil { + t.Fatal(err) + } + WriteToStorage(t, core, "sys/counters/activity/endofmonth", data) + + // Pretend we've successfully rolled over to the following month + a.SetStartTimestamp(tc.NextMonth) + + err = a.precomputedQueryWorker() + if err != nil { + t.Fatal(err) + } + + expectMissingSegment(t, core, 
"sys/counters/activity/endofmonth") + + for i := 0; i <= tc.ExpectedUpTo; i++ { + checkPrecomputedQuery(i) + } + } + + // Check metrics on the last precomputed query + // (otherwise we need a way to reset the in-memory metrics between test cases.) + + intervals := sink.Data() + // Test crossed an interval boundary, don't try to deal with it. + if len(intervals) > 1 { + t.Skip("Detected interval crossing.") + } + expectedGauges := []struct { + Name string + NamespaceLabel string + Value float32 + }{ + // october values + { + "identity.nonentity.active.monthly", + "root", + 15.0, + }, + { + "identity.nonentity.active.monthly", + "deleted-bbbbb", // No namespace entry for this fake ID + 5.0, + }, + { + "identity.nonentity.active.monthly", + "deleted-ccccc", + 5.0, + }, + // august-september values + { + "identity.nonentity.active.reporting_period", + "root", + 20.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-aaaaa", + 5.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-bbbbb", + 10.0, + }, + { + "identity.nonentity.active.reporting_period", + "deleted-ccccc", + 5.0, + }, + } + for _, g := range expectedGauges { + found := false + for _, actual := range intervals[0].Gauges { + actualNamespaceLabel := "" + for _, l := range actual.Labels { + if l.Name == "namespace" { + actualNamespaceLabel = l.Value + break + } + } + if actual.Name == g.Name && actualNamespaceLabel == g.NamespaceLabel { + found = true + if actual.Value != g.Value { + t.Errorf("Mismatched value for %v %v %v != %v", + g.Name, g.NamespaceLabel, actual.Value, g.Value) + } + break + } + } + if !found { + t.Errorf("No guage found for %v %v", + g.Name, g.NamespaceLabel) + } + } +} + type BlockingInmemStorage struct{} func (b *BlockingInmemStorage) List(ctx context.Context, prefix string) ([]string, error) { diff --git a/vault/activity_log_testing_util.go b/vault/activity_log_testing_util.go index 6d3069319..36f97082f 100644 --- a/vault/activity_log_testing_util.go +++ b/vault/activity_log_testing_util.go @@ -2,6 +2,7 @@ package vault import ( "context" + "math/rand" "testing" "github.com/hashicorp/vault/helper/constants" @@ -33,18 +34,18 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) (map[string]uint64, c.activityLog.fragmentLock.Lock() defer c.activityLog.fragmentLock.Unlock() - c.activityLog.currentSegment.tokenCount.CountByNamespaceID = tokens - c.activityLog.entityTracker.entityCountByNamespaceID = entitiesByNS + c.activityLog.clientTracker.nonEntityCountByNamespaceID = tokens + c.activityLog.clientTracker.entityCountByNamespaceID = entitiesByNS return entitiesByNS, tokens } -// Return the in-memory activeEntities from an activity log -func (c *Core) GetActiveEntities() map[string]struct{} { +// Return the in-memory activeClients from an activity log +func (c *Core) GetActiveClients() map[string]struct{} { out := make(map[string]struct{}) c.stateLock.RLock() c.activityLog.fragmentLock.RLock() - for k, v := range c.activityLog.entityTracker.activeEntities { + for k, v := range c.activityLog.clientTracker.activeClients { out[k] = v } c.activityLog.fragmentLock.RUnlock() @@ -57,7 +58,7 @@ func (c *Core) GetActiveEntities() map[string]struct{} { func (a *ActivityLog) GetCurrentEntities() *activity.EntityActivityLog { a.l.RLock() defer a.l.RUnlock() - return a.currentSegment.currentEntities + return a.currentSegment.currentClients } // WriteToStorage is used to put entity data in storage @@ -90,6 +91,29 @@ func (a *ActivityLog) SetStandbyEnable(ctx context.Context, enabled 
bool) { }) } +// NOTE: AddTokenToFragment is deprecated and can no longer be used, except for +// testing backward compatibility. Please use AddClientToFragment instead. +func (a *ActivityLog) AddTokenToFragment(namespaceID string) { + a.fragmentLock.Lock() + defer a.fragmentLock.Unlock() + + if !a.enabled { + return + } + + a.createCurrentFragment() + + a.fragment.NonEntityTokens[namespaceID] += 1 +} + +func RandStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = byte(rand.Intn(26)) + 'a' + } + return string(b) +} + // ExpectCurrentSegmentRefreshed verifies that the current segment has been refreshed // non-nil empty components and updated with the `expectedStart` timestamp // Note: if `verifyTimeNotZero` is true, ignore `expectedStart` and just make sure the timestamp isn't 0 @@ -100,14 +124,11 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart defer a.l.RUnlock() a.fragmentLock.RLock() defer a.fragmentLock.RUnlock() - if a.currentSegment.currentEntities == nil { - t.Fatalf("expected non-nil currentSegment.currentEntities") + if a.currentSegment.currentClients == nil { + t.Fatalf("expected non-nil currentSegment.currentClients") } - if a.currentSegment.currentEntities.Entities == nil { - t.Errorf("expected non-nil currentSegment.currentEntities.Entities") - } - if a.entityTracker.activeEntities == nil { - t.Errorf("expected non-nil activeEntities") + if a.currentSegment.currentClients.Clients == nil { + t.Errorf("expected non-nil currentSegment.currentClients.Entities") } if a.currentSegment.tokenCount == nil { t.Fatalf("expected non-nil currentSegment.tokenCount") @@ -115,16 +136,18 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart if a.currentSegment.tokenCount.CountByNamespaceID == nil { t.Errorf("expected non-nil currentSegment.tokenCount.CountByNamespaceID") } - - if len(a.currentSegment.currentEntities.Entities) > 0 { - t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentEntities) + if a.clientTracker.activeClients == nil { + t.Errorf("expected non-nil activeClients") } - if len(a.entityTracker.activeEntities) > 0 { - t.Errorf("expected no active entity segment to be loaded. got: %v", a.entityTracker.activeEntities) + if len(a.currentSegment.currentClients.Clients) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentClients) } if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { t.Errorf("expected no token counts to be loaded. got: %v", a.currentSegment.tokenCount.CountByNamespaceID) } + if len(a.clientTracker.activeClients) > 0 { + t.Errorf("expected no active entity segment to be loaded. 
got: %v", a.clientTracker.activeClients) + } if verifyTimeNotZero { if a.currentSegment.startTimestamp == 0 { @@ -142,7 +165,7 @@ func ActiveEntitiesEqual(active map[string]struct{}, test []*activity.EntityReco } for _, ent := range test { - if _, ok := active[ent.EntityID]; !ok { + if _, ok := active[ent.ClientID]; !ok { return false } } @@ -164,15 +187,8 @@ func (a *ActivityLog) SetStartTimestamp(timestamp int64) { a.currentSegment.startTimestamp = timestamp } -// SetTokenCount sets the tokenCount on an activity log -func (a *ActivityLog) SetTokenCount(tokenCount *activity.TokenCount) { - a.l.Lock() - defer a.l.Unlock() - a.currentSegment.tokenCount = tokenCount -} - -// GetCountByNamespaceID returns the count of tokens by namespace ID -func (a *ActivityLog) GetCountByNamespaceID() map[string]uint64 { +// GetStoredTokenCountByNamespaceID returns the count of tokens by namespace ID +func (a *ActivityLog) GetStoredTokenCountByNamespaceID() map[string]uint64 { a.l.RLock() defer a.l.RUnlock() return a.currentSegment.tokenCount.CountByNamespaceID @@ -182,7 +198,7 @@ func (a *ActivityLog) GetCountByNamespaceID() map[string]uint64 { func (a *ActivityLog) GetEntitySequenceNumber() uint64 { a.l.RLock() defer a.l.RUnlock() - return a.currentSegment.entitySequenceNumber + return a.currentSegment.clientSequenceNumber } // SetEnable sets the enabled flag on the activity log diff --git a/vault/core.go b/vault/core.go index 39f7b6da7..ce4ff1a86 100644 --- a/vault/core.go +++ b/vault/core.go @@ -44,6 +44,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" + "github.com/hashicorp/vault/sdk/version" sr "github.com/hashicorp/vault/serviceregistration" "github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/vault/cluster" @@ -568,6 +569,10 @@ type Core struct { // enable/disable identifying response headers enableResponseHeaderHostname bool enableResponseHeaderRaftNodeID bool + + // VersionTimestamps is a map of vault versions to timestamps when the version + // was first run + VersionTimestamps map[string]time.Time } func (c *Core) HAState() consts.HAState { @@ -1032,9 +1037,33 @@ func NewCore(conf *CoreConfig) (*Core, error) { return nil, err } + if c.VersionTimestamps == nil { + c.logger.Info("Initializing VersionTimestamps for core") + c.VersionTimestamps = make(map[string]time.Time) + } + return c, nil } +// HandleVersionTimeStamps stores the current version at the current time to +// storage, and then loads all versions and upgrade timestamps out from storage. +func (c *Core) HandleVersionTimeStamps(ctx context.Context) error { + currentTime := time.Now() + isUpdated, err := c.StoreVersionTimestamp(ctx, version.Version, currentTime) + if err != nil { + return err + } + if isUpdated { + c.logger.Info("Recorded vault version", "vault version", version.Version, "upgrade time", currentTime) + } + // Finally, load the versions into core fields + err = c.HandleLoadVersionTimestamps(ctx) + if err != nil { + return err + } + return nil +} + // HostnameHeaderEnabled determines whether to add the X-Vault-Hostname header // to HTTP responses. 
func (c *Core) HostnameHeaderEnabled() bool { @@ -2134,6 +2163,11 @@ func (c *Core) postUnseal(ctx context.Context, ctxCancelFunc context.CancelFunc, c.logger.Warn("post-unseal post seal migration failed", "error", err) } } + err := c.HandleVersionTimeStamps(c.activeContext) + if err != nil { + c.logger.Warn("post-unseal version timestamp setup failed", "error", err) + + } c.logger.Info("post-unseal setup complete") return nil diff --git a/vault/core_test.go b/vault/core_test.go index ce1688212..a0c59d811 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -2,6 +2,7 @@ package vault import ( "context" + "fmt" "reflect" "sync" "testing" @@ -53,6 +54,23 @@ func TestSealConfig_Invalid(t *testing.T) { } } +// TestCore_HasVaultVersion checks that VersionTimestamps are correct and initialized +// after a core has been unsealed. +func TestCore_HasVaultVersion(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + if c.VersionTimestamps == nil { + t.Fatalf("Version timestamps for core were not initialized for a new core") + } + upgradeTime, ok := c.VersionTimestamps["1.9.0"] + if !ok { + t.Fatalf("1.9.0 upgrade time not found") + } + if upgradeTime.After(time.Now()) || upgradeTime.Before(time.Now().Add(-1*time.Hour)) { + t.Fatalf("upgrade time isn't within reasonable bounds of new core initialization. " + + fmt.Sprintf("time is: %+v, upgrade time is %+v", time.Now(), upgradeTime)) + } +} + func TestCore_Unseal_MultiShare(t *testing.T) { c := TestCore(t) diff --git a/vault/core_util_common.go b/vault/core_util_common.go new file mode 100644 index 000000000..a70cac733 --- /dev/null +++ b/vault/core_util_common.go @@ -0,0 +1,110 @@ +package vault + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/hashicorp/vault/sdk/logical" +) + +const vaultVersionPath string = "core/versions/" + +// StoreVersionTimestamp will store the version and timestamp pair to storage only if no entry +// for that version already exists in storage. +func (c *Core) StoreVersionTimestamp(ctx context.Context, version string, currentTime time.Time) (bool, error) { + timeStamp, err := c.barrier.Get(ctx, vaultVersionPath+version) + if err != nil { + return false, err + } + + if timeStamp != nil { + return false, nil + } + + vaultVersion := VaultVersion{TimestampInstalled: currentTime, Version: version} + marshalledVaultVersion, err := json.Marshal(vaultVersion) + if err != nil { + return false, err + } + + err = c.barrier.Put(ctx, &logical.StorageEntry{ + Key: vaultVersionPath + version, + Value: marshalledVaultVersion, + }) + if err != nil { + return false, err + } + return true, nil +} + +// FindMostRecentVersionTimestamp loads the current vault version and associated +// upgrade time from storage. 
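+// (The search walks the in-memory VersionTimestamps map, which
+// HandleLoadVersionTimestamps populates from storage.)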
+func (c *Core) FindMostRecentVersionTimestamp() (string, time.Time, error) { + if c.VersionTimestamps == nil || len(c.VersionTimestamps) == 0 { + return "", time.Time{}, fmt.Errorf("Version timestamps are not initialized") + } + var latestUpgradeTime time.Time + var mostRecentVersion string + for version, upgradeTime := range c.VersionTimestamps { + if upgradeTime.After(latestUpgradeTime) { + mostRecentVersion = version + latestUpgradeTime = upgradeTime + } + } + // This if-case should never be hit + if mostRecentVersion == "" { + return "", latestUpgradeTime, fmt.Errorf("Empty vault version was written to storage at time: %+v", latestUpgradeTime) + } + return mostRecentVersion, latestUpgradeTime, nil +} + +// FindOldestVersionTimestamp searches for the vault version with the oldest +// upgrade timestamp from storage. The earliest version this can be (barring +// downgrades) is 1.9.0. +func (c *Core) FindOldestVersionTimestamp() (string, time.Time, error) { + if c.VersionTimestamps == nil || len(c.VersionTimestamps) == 0 { + return "", time.Time{}, fmt.Errorf("version timestamps are not initialized") + } + + // initialize oldestUpgradeTime to current time + oldestUpgradeTime := time.Now() + var oldestVersion string + for version, upgradeTime := range c.VersionTimestamps { + if upgradeTime.Before(oldestUpgradeTime) { + oldestVersion = version + oldestUpgradeTime = upgradeTime + } + } + return oldestVersion, oldestUpgradeTime, nil +} + +// HandleLoadVersionTimestamps loads all the vault versions and associated +// upgrade timestamps from storage. +func (c *Core) HandleLoadVersionTimestamps(ctx context.Context) (retErr error) { + vaultVersions, err := c.barrier.List(ctx, vaultVersionPath) + if err != nil { + return fmt.Errorf("unable to retrieve vault versions from storage: %+w", err) + } + + for _, versionPath := range vaultVersions { + version, err := c.barrier.Get(ctx, vaultVersionPath+versionPath) + if err != nil { + return fmt.Errorf("unable to read vault version at path %s: err %+w", versionPath, err) + } + if version == nil { + return fmt.Errorf("nil version stored at path %s", versionPath) + } + var vaultVersion VaultVersion + err = json.Unmarshal(version.Value, &vaultVersion) + if err != nil { + return fmt.Errorf("unable to unmarshal vault version for path %s: err %w", versionPath, err) + } + if vaultVersion.Version == "" || vaultVersion.TimestampInstalled.IsZero() { + return fmt.Errorf("found empty serialized vault version at path %s", versionPath) + } + c.VersionTimestamps[vaultVersion.Version] = vaultVersion.TimestampInstalled + } + return nil +} diff --git a/vault/core_util_common_test.go b/vault/core_util_common_test.go new file mode 100644 index 000000000..e47912e3d --- /dev/null +++ b/vault/core_util_common_test.go @@ -0,0 +1,48 @@ +package vault + +import ( + "context" + "testing" + "time" +) + +// TestStoreMultipleVaultVersions writes multiple versions of 1.9.0 and verifies that only +// the original timestamp is stored. 
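+// (TestCoreUnsealed already records the running version during postUnseal, so
+// the StoreVersionTimestamp call below is expected to be a no-op that returns false.)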
+func TestStoreMultipleVaultVersions(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + upgradeTimePlusEpsilon := time.Now() + wasStored, err := c.StoreVersionTimestamp(context.Background(), "1.9.0", upgradeTimePlusEpsilon.Add(30*time.Hour)) + if err != nil || wasStored { + t.Fatalf("vault version was re-stored: %v, err is: %s", wasStored, err.Error()) + } + upgradeTime, ok := c.VersionTimestamps["1.9.0"] + if !ok { + t.Fatalf("no 1.9.0 version timestamp found") + } + if upgradeTime.After(upgradeTimePlusEpsilon) { + t.Fatalf("upgrade time for 1.9.0 is incorrect: got %+v, expected less than %+v", upgradeTime, upgradeTimePlusEpsilon) + } +} + +// TestGetOldestVersion verifies that FindOldestVersionTimestamp finds the oldest +// vault version stored. +func TestGetOldestVersion(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + upgradeTimePlusEpsilon := time.Now() + c.StoreVersionTimestamp(context.Background(), "1.9.1", upgradeTimePlusEpsilon.Add(-4*time.Hour)) + c.StoreVersionTimestamp(context.Background(), "1.9.2", upgradeTimePlusEpsilon.Add(2*time.Hour)) + c.HandleLoadVersionTimestamps(c.activeContext) + if len(c.VersionTimestamps) != 3 { + t.Fatalf("expected 3 entries in timestamps map after refresh, found: %d", len(c.VersionTimestamps)) + } + v, tm, err := c.FindOldestVersionTimestamp() + if err != nil { + t.Fatal(err) + } + if v != "1.9.1" { + t.Fatalf("expected 1.9.1, found: %s", v) + } + if tm.Before(upgradeTimePlusEpsilon.Add(-6*time.Hour)) || tm.After(upgradeTimePlusEpsilon.Add(-2*time.Hour)) { + t.Fatalf("incorrect upgrade time logged: %v", tm) + } +} diff --git a/vault/request_handling.go b/vault/request_handling.go index ba53c82fe..16e233e8f 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -396,10 +396,11 @@ func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool return auth, te, retErr } - // If it is an authenticated ( i.e with vault token ) request - // associated with an entity, increment client count - if !unauth && c.activityLog != nil && te.EntityID != "" { - c.activityLog.HandleTokenCreation(te) + // If it is an authenticated ( i.e with vault token ) request, increment client count + if !unauth && c.activityLog != nil { + clientID, _ := c.activityLog.CreateClientID(req.TokenEntry()) + req.ClientID = clientID + c.activityLog.HandleTokenUsage(te) } return auth, te, nil } diff --git a/vault/token_store.go b/vault/token_store.go index 8a5079058..85d82244a 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -508,8 +508,7 @@ type TokenStore struct { parentBarrierView *BarrierView rolesBarrierView *BarrierView - expiration *ExpirationManager - activityLog *ActivityLog + expiration *ExpirationManager cubbyholeBackend *CubbyholeBackend @@ -686,12 +685,6 @@ func (ts *TokenStore) SetExpirationManager(exp *ExpirationManager) { ts.expiration = exp } -// SetActivityLog injects the activity log to which all new -// token creation events are reported. 
-func (ts *TokenStore) SetActivityLog(a *ActivityLog) { - ts.activityLog = a -} - // SaltID is used to apply a salt and hash to an ID to make sure its not reversible func (ts *TokenStore) SaltID(ctx context.Context, id string) (string, error) { ns, err := namespace.FromContext(ctx) @@ -910,11 +903,6 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err return err } - // Update the activity log in case the token has no entity - if ts.activityLog != nil && entry.EntityID == "" { - ts.activityLog.HandleTokenCreation(entry) - } - return ts.storeCommon(ctx, entry, true) case logical.TokenTypeBatch: @@ -961,11 +949,6 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err entry.ID = fmt.Sprintf("%s.%s", entry.ID, tokenNS.ID) } - // Update the activity log in case the token has no entity - if ts.activityLog != nil && entry.EntityID == "" { - ts.activityLog.HandleTokenCreation(entry) - } - return nil default: diff --git a/vault/vault_version_time.go b/vault/vault_version_time.go new file mode 100644 index 000000000..40f3f2813 --- /dev/null +++ b/vault/vault_version_time.go @@ -0,0 +1,8 @@ +package vault + +import "time" + +type VaultVersion struct { + TimestampInstalled time.Time + Version string +}
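
For reference, the TWE client ID format asserted by TestCreateClientID above can be reproduced with a few lines of standalone Go. The sketch below covers only the common case of a token that has both policies and a namespace; the delimiter values and the helper name are illustrative placeholders, since the real constants (sortedPoliciesTWEDelimiter, clientIDTWEDelimiter) are defined outside this diff. The test's no-policies case also shows that the real helper omits the delimiters entirely in that situation, which this sketch does not reproduce.

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

// deriveTWEClientID sketches the client ID derivation asserted in
// TestCreateClientID for a token without an entity: sort the policies, join
// them with the policy delimiter, append the namespace ID after a second
// delimiter, SHA-256 the result, and base64-URL encode the digest.
// The delimiter values passed in are placeholders, not the real constants.
func deriveTWEClientID(policies []string, namespaceID, policyDelim, nsDelim string) string {
	sorted := append([]string(nil), policies...)
	sort.Strings(sorted)
	plaintext := strings.Join(sorted, policyDelim) + nsDelim + namespaceID
	sum := sha256.Sum256([]byte(plaintext))
	return base64.URLEncoding.EncodeToString(sum[:])
}

func main() {
	// Mirrors the first case in TestCreateClientID (delimiters assumed to be ".").
	fmt.Println(deriveTWEClientID([]string{"bar", "baz", "foo", "banana"}, "namespaceFoo", ".", "."))
}
```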