[VAULT-3157] Move `mergeStates` utils from Agent to api module (#12731)
* move merge and compare states to vault core * move MergeState, CompareStates and ParseRequiredStates to api package * fix merge state reference in API Proxy * move mergeStates test to api package * add changelog * ghost commit to trigger CI * rename CompareStates to CompareReplicationStates * rename MergeStates and make compareStates and parseStates private methods * improved error messaging in parseReplicationState * export ParseReplicationState for enterprise files
This commit is contained in:
parent
79662d0842
commit
458927c2ed
106
api/client.go
106
api/client.go
|
@ -2,7 +2,11 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -21,6 +25,8 @@ import (
|
|||
rootcerts "github.com/hashicorp/go-rootcerts"
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
@ -1161,6 +1167,106 @@ func RequireState(states ...string) RequestCallback {
|
|||
}
|
||||
}
|
||||
|
||||
// compareReplicationStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
|
||||
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
|
||||
// are invalid or from different clusters.
|
||||
func compareReplicationStates(s1, s2 string) (int, error) {
|
||||
w1, err := ParseReplicationState(s1, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w2, err := ParseReplicationState(s2, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w1.ClusterID != w2.ClusterID {
|
||||
return 0, fmt.Errorf("can't compare replication states with different ClusterIDs")
|
||||
}
|
||||
|
||||
switch {
|
||||
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
|
||||
return 1, nil
|
||||
// We've already handled the case where both are equal above, so really we're
|
||||
// asking here if one or both are lesser.
|
||||
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// MergeReplicationStates returns a merged array of replication states by iterating
|
||||
// through all states in `old`. An iterated state is merged to the result before `new`
|
||||
// based on the result of compareReplicationStates
|
||||
func MergeReplicationStates(old []string, new string) []string {
|
||||
if len(old) == 0 || len(old) > 2 {
|
||||
return []string{new}
|
||||
}
|
||||
|
||||
var ret []string
|
||||
for _, o := range old {
|
||||
c, err := compareReplicationStates(o, new)
|
||||
if err != nil {
|
||||
return []string{new}
|
||||
}
|
||||
switch c {
|
||||
case 1:
|
||||
ret = append(ret, o)
|
||||
case -1:
|
||||
ret = append(ret, new)
|
||||
case 0:
|
||||
ret = append(ret, o, new)
|
||||
}
|
||||
}
|
||||
return strutil.RemoveDuplicates(ret, false)
|
||||
}
|
||||
|
||||
func ParseReplicationState(raw string, hmacKey []byte) (*logical.WALState, error) {
|
||||
cooked, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := string(cooked)
|
||||
|
||||
lastIndex := strings.LastIndexByte(s, ':')
|
||||
if lastIndex == -1 {
|
||||
return nil, fmt.Errorf("invalid full state header format")
|
||||
}
|
||||
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
|
||||
stateHMAC, err := hex.DecodeString(stateHMACRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
|
||||
}
|
||||
|
||||
if len(hmacKey) != 0 {
|
||||
hm := hmac.New(sha256.New, hmacKey)
|
||||
hm.Write([]byte(state))
|
||||
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
|
||||
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
|
||||
}
|
||||
}
|
||||
|
||||
pieces := strings.Split(state, ":")
|
||||
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
|
||||
return nil, fmt.Errorf("invalid state header format")
|
||||
}
|
||||
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid local index in state header: %w", err)
|
||||
}
|
||||
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid replicated index in state header: %w", err)
|
||||
}
|
||||
|
||||
return &logical.WALState{
|
||||
ClusterID: pieces[1],
|
||||
LocalIndex: localIndex,
|
||||
ReplicatedIndex: replicatedIndex,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardInconsistent returns a request callback that will add a request
|
||||
// header which says: if the state required isn't present on the node receiving
|
||||
// this request, forward it to the active node. This should be used in
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
)
|
||||
|
@ -562,3 +564,85 @@ func TestSetHeadersRaceSafe(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeReplicationStates(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
old []string
|
||||
new string
|
||||
expected []string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "empty-old",
|
||||
old: nil,
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "old-smaller",
|
||||
old: []string{"v1:cid:1:0:"},
|
||||
new: "v1:cid:2:0:",
|
||||
expected: []string{"v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "old-bigger",
|
||||
old: []string{"v1:cid:2:0:"},
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-single",
|
||||
old: []string{"v1:cid:1:0:"},
|
||||
new: "v1:cid:0:1:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-single-alt",
|
||||
old: []string{"v1:cid:0:1:"},
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-double",
|
||||
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
new: "v1:cid:2:0:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "newer-both",
|
||||
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
new: "v1:cid:2:1:",
|
||||
expected: []string{"v1:cid:2:1:"},
|
||||
},
|
||||
}
|
||||
|
||||
b64enc := func(ss []string) []string {
|
||||
var ret []string
|
||||
for _, s := range ss {
|
||||
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
b64dec := func(ss []string) []string {
|
||||
var ret []string
|
||||
for _, s := range ss {
|
||||
d, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ret = append(ret, string(d))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
|
||||
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
|
||||
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
api: Move mergeStates and other required utils from agent to api module
|
||||
```
|
|
@ -7,10 +7,8 @@ import (
|
|||
|
||||
hclog "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
type EnforceConsistency int
|
||||
|
@ -60,58 +58,6 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
|
||||
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
|
||||
// are invalid or from different clusters.
|
||||
func compareStates(s1, s2 string) (int, error) {
|
||||
w1, err := vault.ParseRequiredState(s1, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w2, err := vault.ParseRequiredState(s2, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if w1.ClusterID != w2.ClusterID {
|
||||
return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs")
|
||||
}
|
||||
|
||||
switch {
|
||||
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
|
||||
return 1, nil
|
||||
// We've already handled the case where both are equal above, so really we're
|
||||
// asking here if one or both are lesser.
|
||||
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func mergeStates(old []string, new string) []string {
|
||||
if len(old) == 0 || len(old) > 2 {
|
||||
return []string{new}
|
||||
}
|
||||
|
||||
var ret []string
|
||||
for _, o := range old {
|
||||
c, err := compareStates(o, new)
|
||||
if err != nil {
|
||||
return []string{new}
|
||||
}
|
||||
switch c {
|
||||
case 1:
|
||||
ret = append(ret, o)
|
||||
case -1:
|
||||
ret = append(ret, new)
|
||||
case 0:
|
||||
ret = append(ret, o, new)
|
||||
}
|
||||
}
|
||||
return strutil.RemoveDuplicates(ret, false)
|
||||
}
|
||||
|
||||
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
client, err := ap.client.Clone()
|
||||
if err != nil {
|
||||
|
@ -184,7 +130,7 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
|
|||
// Vault uses to do so. So instead we compare any of the 0-2 saved states
|
||||
// we have to the new header, keeping the newest 1-2 of these, and sending
|
||||
// them to Vault to evaluate.
|
||||
ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState)
|
||||
ap.lastIndexStates = api.MergeReplicationStates(ap.lastIndexStates, newState)
|
||||
ap.l.Unlock()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
|
@ -96,85 +93,3 @@ func TestAPIProxy_queryParams(t *testing.T) {
|
|||
t.Fatalf("exptected standby to return 200, got: %v", resp.Response.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeStates(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
old []string
|
||||
new string
|
||||
expected []string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "empty-old",
|
||||
old: nil,
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "old-smaller",
|
||||
old: []string{"v1:cid:1:0:"},
|
||||
new: "v1:cid:2:0:",
|
||||
expected: []string{"v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "old-bigger",
|
||||
old: []string{"v1:cid:2:0:"},
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-single",
|
||||
old: []string{"v1:cid:1:0:"},
|
||||
new: "v1:cid:0:1:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-single-alt",
|
||||
old: []string{"v1:cid:0:1:"},
|
||||
new: "v1:cid:1:0:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
},
|
||||
{
|
||||
name: "mixed-double",
|
||||
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
new: "v1:cid:2:0:",
|
||||
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
|
||||
},
|
||||
{
|
||||
name: "newer-both",
|
||||
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
||||
new: "v1:cid:2:1:",
|
||||
expected: []string{"v1:cid:2:1:"},
|
||||
},
|
||||
}
|
||||
|
||||
b64enc := func(ss []string) []string {
|
||||
var ret []string
|
||||
for _, s := range ss {
|
||||
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
b64dec := func(ss []string) []string {
|
||||
var ret []string
|
||||
for _, s := range ss {
|
||||
d, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ret = append(ret, string(d))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
|
||||
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
|
||||
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,14 +3,10 @@ package vault
|
|||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -20,7 +16,6 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -2824,51 +2819,6 @@ func (c *Core) isPrimary() bool {
|
|||
return !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary | consts.ReplicationDRSecondary)
|
||||
}
|
||||
|
||||
func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) {
|
||||
cooked, err := base64.StdEncoding.DecodeString(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := string(cooked)
|
||||
|
||||
lastIndex := strings.LastIndexByte(s, ':')
|
||||
if lastIndex == -1 {
|
||||
return nil, fmt.Errorf("invalid state header format")
|
||||
}
|
||||
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
|
||||
stateHMAC, err := hex.DecodeString(stateHMACRaw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
|
||||
}
|
||||
|
||||
if len(hmacKey) != 0 {
|
||||
hm := hmac.New(sha256.New, hmacKey)
|
||||
hm.Write([]byte(state))
|
||||
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
|
||||
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
|
||||
}
|
||||
}
|
||||
|
||||
pieces := strings.Split(state, ":")
|
||||
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
|
||||
return nil, fmt.Errorf("invalid state header format")
|
||||
}
|
||||
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid state header format")
|
||||
}
|
||||
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid state header format")
|
||||
}
|
||||
|
||||
return &logical.WALState{
|
||||
ClusterID: pieces[1],
|
||||
LocalIndex: localIndex,
|
||||
ReplicatedIndex: replicatedIndex,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type LicenseState struct {
|
||||
State string
|
||||
ExpiryTime time.Time
|
||||
|
|
Loading…
Reference in New Issue