[VAULT-3157] Move `mergeStates` utils from Agent to api module (#12731)

* move merge and compare states to vault core

* move MergeState, CompareStates and ParseRequiredStates to api package

* fix merge state reference in API Proxy

* move mergeStates test to api package

* add changelog

* ghost commit to trigger CI

* rename CompareStates to CompareReplicationStates

* rename MergeStates and make compareStates and parseStates private methods

* improved error messaging in parseReplicationState

* export ParseReplicationState for enterprise files
This commit is contained in:
vinay-gopalan 2021-10-06 10:57:06 -07:00 committed by GitHub
parent 79662d0842
commit 458927c2ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 194 additions and 190 deletions

View File

@ -2,7 +2,11 @@ package api
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/http"
@ -21,6 +25,8 @@ import (
rootcerts "github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)
@ -1161,6 +1167,106 @@ func RequireState(states ...string) RequestCallback {
}
}
// compareReplicationStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareReplicationStates(s1, s2 string) (int, error) {
w1, err := ParseReplicationState(s1, nil)
if err != nil {
return 0, err
}
w2, err := ParseReplicationState(s2, nil)
if err != nil {
return 0, err
}
if w1.ClusterID != w2.ClusterID {
return 0, fmt.Errorf("can't compare replication states with different ClusterIDs")
}
switch {
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
return 1, nil
// We've already handled the case where both are equal above, so really we're
// asking here if one or both are lesser.
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
return -1, nil
}
return 0, nil
}
// MergeReplicationStates returns a merged array of replication states by iterating
// through all states in `old`. An iterated state is merged to the result before `new`
// based on the result of compareReplicationStates
func MergeReplicationStates(old []string, new string) []string {
if len(old) == 0 || len(old) > 2 {
return []string{new}
}
var ret []string
for _, o := range old {
c, err := compareReplicationStates(o, new)
if err != nil {
return []string{new}
}
switch c {
case 1:
ret = append(ret, o)
case -1:
ret = append(ret, new)
case 0:
ret = append(ret, o, new)
}
}
return strutil.RemoveDuplicates(ret, false)
}
func ParseReplicationState(raw string, hmacKey []byte) (*logical.WALState, error) {
cooked, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
s := string(cooked)
lastIndex := strings.LastIndexByte(s, ':')
if lastIndex == -1 {
return nil, fmt.Errorf("invalid full state header format")
}
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
stateHMAC, err := hex.DecodeString(stateHMACRaw)
if err != nil {
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
}
if len(hmacKey) != 0 {
hm := hmac.New(sha256.New, hmacKey)
hm.Write([]byte(state))
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
}
}
pieces := strings.Split(state, ":")
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
return nil, fmt.Errorf("invalid state header format")
}
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid local index in state header: %w", err)
}
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid replicated index in state header: %w", err)
}
return &logical.WALState{
ClusterID: pieces[1],
LocalIndex: localIndex,
ReplicatedIndex: replicatedIndex,
}, nil
}
// ForwardInconsistent returns a request callback that will add a request
// header which says: if the state required isn't present on the node receiving
// this request, forward it to the active node. This should be used in

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
@ -14,6 +15,7 @@ import (
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/consts"
)
@ -562,3 +564,85 @@ func TestSetHeadersRaceSafe(t *testing.T) {
}
}
}
func TestMergeReplicationStates(t *testing.T) {
type testCase struct {
name string
old []string
new string
expected []string
}
testCases := []testCase{
{
name: "empty-old",
old: nil,
new: "v1:cid:1:0:",
expected: []string{"v1:cid:1:0:"},
},
{
name: "old-smaller",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "old-bigger",
old: []string{"v1:cid:2:0:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "mixed-single",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:0:1:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-single-alt",
old: []string{"v1:cid:0:1:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-double",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
},
{
name: "newer-both",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:1:",
expected: []string{"v1:cid:2:1:"},
},
}
b64enc := func(ss []string) []string {
var ret []string
for _, s := range ss {
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
}
return ret
}
b64dec := func(ss []string) []string {
var ret []string
for _, s := range ss {
d, err := base64.StdEncoding.DecodeString(s)
if err != nil {
t.Fatal(err)
}
ret = append(ret, string(d))
}
return ret
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
}
})
}
}

3
changelog/12731.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
api: Move mergeStates and other required utils from agent to api module
```

View File

@ -7,10 +7,8 @@ import (
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
)
type EnforceConsistency int
@ -60,58 +58,6 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
}, nil
}
// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareStates(s1, s2 string) (int, error) {
w1, err := vault.ParseRequiredState(s1, nil)
if err != nil {
return 0, err
}
w2, err := vault.ParseRequiredState(s2, nil)
if err != nil {
return 0, err
}
if w1.ClusterID != w2.ClusterID {
return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs")
}
switch {
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
return 1, nil
// We've already handled the case where both are equal above, so really we're
// asking here if one or both are lesser.
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
return -1, nil
}
return 0, nil
}
func mergeStates(old []string, new string) []string {
if len(old) == 0 || len(old) > 2 {
return []string{new}
}
var ret []string
for _, o := range old {
c, err := compareStates(o, new)
if err != nil {
return []string{new}
}
switch c {
case 1:
ret = append(ret, o)
case -1:
ret = append(ret, new)
case 0:
ret = append(ret, o, new)
}
}
return strutil.RemoveDuplicates(ret, false)
}
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
client, err := ap.client.Clone()
if err != nil {
@ -184,7 +130,7 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
// Vault uses to do so. So instead we compare any of the 0-2 saved states
// we have to the new header, keeping the newest 1-2 of these, and sending
// them to Vault to evaluate.
ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState)
ap.lastIndexStates = api.MergeReplicationStates(ap.lastIndexStates, newState)
ap.l.Unlock()
}

View File

@ -1,12 +1,9 @@
package cache
import (
"encoding/base64"
"net/http"
"testing"
"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/namespace"
@ -96,85 +93,3 @@ func TestAPIProxy_queryParams(t *testing.T) {
t.Fatalf("exptected standby to return 200, got: %v", resp.Response.StatusCode)
}
}
func TestMergeStates(t *testing.T) {
type testCase struct {
name string
old []string
new string
expected []string
}
testCases := []testCase{
{
name: "empty-old",
old: nil,
new: "v1:cid:1:0:",
expected: []string{"v1:cid:1:0:"},
},
{
name: "old-smaller",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "old-bigger",
old: []string{"v1:cid:2:0:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "mixed-single",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:0:1:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-single-alt",
old: []string{"v1:cid:0:1:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-double",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
},
{
name: "newer-both",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:1:",
expected: []string{"v1:cid:2:1:"},
},
}
b64enc := func(ss []string) []string {
var ret []string
for _, s := range ss {
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
}
return ret
}
b64dec := func(ss []string) []string {
var ret []string
for _, s := range ss {
d, err := base64.StdEncoding.DecodeString(s)
if err != nil {
t.Fatal(err)
}
ret = append(ret, string(d))
}
return ret
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
}
})
}
}

View File

@ -3,14 +3,10 @@ package vault
import (
"context"
"crypto/ecdsa"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -20,7 +16,6 @@ import (
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -2824,51 +2819,6 @@ func (c *Core) isPrimary() bool {
return !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary | consts.ReplicationDRSecondary)
}
func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) {
cooked, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
s := string(cooked)
lastIndex := strings.LastIndexByte(s, ':')
if lastIndex == -1 {
return nil, fmt.Errorf("invalid state header format")
}
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
stateHMAC, err := hex.DecodeString(stateHMACRaw)
if err != nil {
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
}
if len(hmacKey) != 0 {
hm := hmac.New(sha256.New, hmacKey)
hm.Write([]byte(state))
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
}
}
pieces := strings.Split(state, ":")
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
return nil, fmt.Errorf("invalid state header format")
}
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid state header format")
}
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid state header format")
}
return &logical.WALState{
ClusterID: pieces[1],
LocalIndex: localIndex,
ReplicatedIndex: replicatedIndex,
}, nil
}
type LicenseState struct {
State string
ExpiryTime time.Time