diff --git a/api/client.go b/api/client.go index df8cfa551..34974d742 100644 --- a/api/client.go +++ b/api/client.go @@ -2,7 +2,11 @@ package api import ( "context" + "crypto/hmac" + "crypto/sha256" "crypto/tls" + "encoding/base64" + "encoding/hex" "fmt" "net" "net/http" @@ -21,6 +25,8 @@ import ( rootcerts "github.com/hashicorp/go-rootcerts" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/logical" "golang.org/x/net/http2" "golang.org/x/time/rate" ) @@ -1161,6 +1167,106 @@ func RequireState(states ...string) RequestCallback { } } +// compareReplicationStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0 +// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2 +// are invalid or from different clusters. +func compareReplicationStates(s1, s2 string) (int, error) { + w1, err := ParseReplicationState(s1, nil) + if err != nil { + return 0, err + } + w2, err := ParseReplicationState(s2, nil) + if err != nil { + return 0, err + } + + if w1.ClusterID != w2.ClusterID { + return 0, fmt.Errorf("can't compare replication states with different ClusterIDs") + } + + switch { + case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex: + return 1, nil + // We've already handled the case where both are equal above, so really we're + // asking here if one or both are lesser. + case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex: + return -1, nil + } + + return 0, nil +} + +// MergeReplicationStates returns a merged array of replication states by iterating +// through all states in `old`. An iterated state is merged to the result before `new` +// based on the result of compareReplicationStates +func MergeReplicationStates(old []string, new string) []string { + if len(old) == 0 || len(old) > 2 { + return []string{new} + } + + var ret []string + for _, o := range old { + c, err := compareReplicationStates(o, new) + if err != nil { + return []string{new} + } + switch c { + case 1: + ret = append(ret, o) + case -1: + ret = append(ret, new) + case 0: + ret = append(ret, o, new) + } + } + return strutil.RemoveDuplicates(ret, false) +} + +func ParseReplicationState(raw string, hmacKey []byte) (*logical.WALState, error) { + cooked, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, err + } + s := string(cooked) + + lastIndex := strings.LastIndexByte(s, ':') + if lastIndex == -1 { + return nil, fmt.Errorf("invalid full state header format") + } + state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:] + stateHMAC, err := hex.DecodeString(stateHMACRaw) + if err != nil { + return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err) + } + + if len(hmacKey) != 0 { + hm := hmac.New(sha256.New, hmacKey) + hm.Write([]byte(state)) + if !hmac.Equal(hm.Sum(nil), stateHMAC) { + return nil, fmt.Errorf("invalid state header HMAC (mismatch)") + } + } + + pieces := strings.Split(state, ":") + if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" { + return nil, fmt.Errorf("invalid state header format") + } + localIndex, err := strconv.ParseUint(pieces[2], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid local index in state header: %w", err) + } + replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid replicated index in state header: %w", err) + } + + return &logical.WALState{ + ClusterID: pieces[1], + LocalIndex: localIndex, + ReplicatedIndex: replicatedIndex, + }, nil +} + // ForwardInconsistent returns a request callback that will add a request // header which says: if the state required isn't present on the node receiving // this request, forward it to the active node. This should be used in diff --git a/api/client_test.go b/api/client_test.go index 474db04e1..3b306a927 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/x509" + "encoding/base64" "fmt" "io" "net/http" @@ -14,6 +15,7 @@ import ( "testing" "time" + "github.com/go-test/deep" "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/consts" ) @@ -562,3 +564,85 @@ func TestSetHeadersRaceSafe(t *testing.T) { } } } + +func TestMergeReplicationStates(t *testing.T) { + type testCase struct { + name string + old []string + new string + expected []string + } + + testCases := []testCase{ + { + name: "empty-old", + old: nil, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:1:0:"}, + }, + { + name: "old-smaller", + old: []string{"v1:cid:1:0:"}, + new: "v1:cid:2:0:", + expected: []string{"v1:cid:2:0:"}, + }, + { + name: "old-bigger", + old: []string{"v1:cid:2:0:"}, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:2:0:"}, + }, + { + name: "mixed-single", + old: []string{"v1:cid:1:0:"}, + new: "v1:cid:0:1:", + expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + }, + { + name: "mixed-single-alt", + old: []string{"v1:cid:0:1:"}, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + }, + { + name: "mixed-double", + old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + new: "v1:cid:2:0:", + expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"}, + }, + { + name: "newer-both", + old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + new: "v1:cid:2:1:", + expected: []string{"v1:cid:2:1:"}, + }, + } + + b64enc := func(ss []string) []string { + var ret []string + for _, s := range ss { + ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s))) + } + return ret + } + b64dec := func(ss []string) []string { + var ret []string + for _, s := range ss { + d, err := base64.StdEncoding.DecodeString(s) + if err != nil { + t.Fatal(err) + } + ret = append(ret, string(d)) + } + return ret + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new)))) + if diff := deep.Equal(out, tc.expected); len(diff) != 0 { + t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff) + } + }) + } +} diff --git a/changelog/12731.txt b/changelog/12731.txt new file mode 100644 index 000000000..c88366879 --- /dev/null +++ b/changelog/12731.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Move mergeStates and other required utils from agent to api module +``` \ No newline at end of file diff --git a/command/agent/cache/api_proxy.go b/command/agent/cache/api_proxy.go index 182220129..9523d310b 100644 --- a/command/agent/cache/api_proxy.go +++ b/command/agent/cache/api_proxy.go @@ -7,10 +7,8 @@ import ( hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-retryablehttp" - "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/vault" ) type EnforceConsistency int @@ -60,58 +58,6 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) { }, nil } -// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0 -// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2 -// are invalid or from different clusters. -func compareStates(s1, s2 string) (int, error) { - w1, err := vault.ParseRequiredState(s1, nil) - if err != nil { - return 0, err - } - w2, err := vault.ParseRequiredState(s2, nil) - if err != nil { - return 0, err - } - - if w1.ClusterID != w2.ClusterID { - return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs") - } - - switch { - case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex: - return 1, nil - // We've already handled the case where both are equal above, so really we're - // asking here if one or both are lesser. - case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex: - return -1, nil - } - - return 0, nil -} - -func mergeStates(old []string, new string) []string { - if len(old) == 0 || len(old) > 2 { - return []string{new} - } - - var ret []string - for _, o := range old { - c, err := compareStates(o, new) - if err != nil { - return []string{new} - } - switch c { - case 1: - ret = append(ret, o) - case -1: - ret = append(ret, new) - case 0: - ret = append(ret, o, new) - } - } - return strutil.RemoveDuplicates(ret, false) -} - func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { client, err := ap.client.Clone() if err != nil { @@ -184,7 +130,7 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, // Vault uses to do so. So instead we compare any of the 0-2 saved states // we have to the new header, keeping the newest 1-2 of these, and sending // them to Vault to evaluate. - ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState) + ap.lastIndexStates = api.MergeReplicationStates(ap.lastIndexStates, newState) ap.l.Unlock() } diff --git a/command/agent/cache/api_proxy_test.go b/command/agent/cache/api_proxy_test.go index 7ce833d82..b90f579c3 100644 --- a/command/agent/cache/api_proxy_test.go +++ b/command/agent/cache/api_proxy_test.go @@ -1,12 +1,9 @@ package cache import ( - "encoding/base64" "net/http" "testing" - "github.com/go-test/deep" - "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/namespace" @@ -96,85 +93,3 @@ func TestAPIProxy_queryParams(t *testing.T) { t.Fatalf("exptected standby to return 200, got: %v", resp.Response.StatusCode) } } - -func TestMergeStates(t *testing.T) { - type testCase struct { - name string - old []string - new string - expected []string - } - - testCases := []testCase{ - { - name: "empty-old", - old: nil, - new: "v1:cid:1:0:", - expected: []string{"v1:cid:1:0:"}, - }, - { - name: "old-smaller", - old: []string{"v1:cid:1:0:"}, - new: "v1:cid:2:0:", - expected: []string{"v1:cid:2:0:"}, - }, - { - name: "old-bigger", - old: []string{"v1:cid:2:0:"}, - new: "v1:cid:1:0:", - expected: []string{"v1:cid:2:0:"}, - }, - { - name: "mixed-single", - old: []string{"v1:cid:1:0:"}, - new: "v1:cid:0:1:", - expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, - }, - { - name: "mixed-single-alt", - old: []string{"v1:cid:0:1:"}, - new: "v1:cid:1:0:", - expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, - }, - { - name: "mixed-double", - old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, - new: "v1:cid:2:0:", - expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"}, - }, - { - name: "newer-both", - old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, - new: "v1:cid:2:1:", - expected: []string{"v1:cid:2:1:"}, - }, - } - - b64enc := func(ss []string) []string { - var ret []string - for _, s := range ss { - ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s))) - } - return ret - } - b64dec := func(ss []string) []string { - var ret []string - for _, s := range ss { - d, err := base64.StdEncoding.DecodeString(s) - if err != nil { - t.Fatal(err) - } - ret = append(ret, string(d)) - } - return ret - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new)))) - if diff := deep.Equal(out, tc.expected); len(diff) != 0 { - t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff) - } - }) - } -} diff --git a/vault/core.go b/vault/core.go index 49ca47ed3..7b20447f0 100644 --- a/vault/core.go +++ b/vault/core.go @@ -3,14 +3,10 @@ package vault import ( "context" "crypto/ecdsa" - "crypto/hmac" "crypto/rand" - "crypto/sha256" "crypto/subtle" "crypto/tls" "crypto/x509" - "encoding/base64" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -20,7 +16,6 @@ import ( "net/url" "os" "path/filepath" - "strconv" "strings" "sync" "sync/atomic" @@ -2824,51 +2819,6 @@ func (c *Core) isPrimary() bool { return !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary | consts.ReplicationDRSecondary) } -func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) { - cooked, err := base64.StdEncoding.DecodeString(raw) - if err != nil { - return nil, err - } - s := string(cooked) - - lastIndex := strings.LastIndexByte(s, ':') - if lastIndex == -1 { - return nil, fmt.Errorf("invalid state header format") - } - state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:] - stateHMAC, err := hex.DecodeString(stateHMACRaw) - if err != nil { - return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err) - } - - if len(hmacKey) != 0 { - hm := hmac.New(sha256.New, hmacKey) - hm.Write([]byte(state)) - if !hmac.Equal(hm.Sum(nil), stateHMAC) { - return nil, fmt.Errorf("invalid state header HMAC (mismatch)") - } - } - - pieces := strings.Split(state, ":") - if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" { - return nil, fmt.Errorf("invalid state header format") - } - localIndex, err := strconv.ParseUint(pieces[2], 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid state header format") - } - replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid state header format") - } - - return &logical.WALState{ - ClusterID: pieces[1], - LocalIndex: localIndex, - ReplicatedIndex: replicatedIndex, - }, nil -} - type LicenseState struct { State string ExpiryTime time.Time