open-vault/vault/request_forwarding_rpc.go

194 lines
5.5 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"context"
"net/http"
"os"
"runtime/debug"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/vault/helper/forwarding"
"github.com/hashicorp/vault/physical/raft"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault/replication"
)
type forwardedRequestRPCServer struct {
UnimplementedRequestForwardingServer
core *Core
handler http.Handler
perfStandbySlots chan struct{}
perfStandbyRepCluster *replication.Cluster
raftFollowerStates *raft.FollowerStates
}
func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *forwarding.Request) (*forwarding.Response, error) {
// Parse an http.Request out of it
req, err := forwarding.ParseForwardedRequest(freq)
if err != nil {
return nil, err
}
// A very dummy response writer that doesn't follow normal semantics, just
// lets you write a status code (last written wins) and a body. But it
// meets the interface requirements.
w := forwarding.NewRPCResponseWriter()
resp := &forwarding.Response{}
runRequest := func() {
defer func() {
if err := recover(); err != nil {
s.core.logger.Error("panic serving forwarded request", "path", req.URL.Path, "error", err, "stacktrace", string(debug.Stack()))
}
}()
s.handler.ServeHTTP(w, req)
}
runRequest()
resp.StatusCode = uint32(w.StatusCode())
resp.Body = w.Body().Bytes()
header := w.Header()
if header != nil {
resp.HeaderEntries = make(map[string]*forwarding.HeaderEntry, len(header))
for k, v := range header {
resp.HeaderEntries[k] = &forwarding.HeaderEntry{
Values: v,
}
}
}
// Performance standby nodes will use this value to do wait for WALs to ship
// in order to do a best-effort read after write guarantee
resp.LastRemoteWal = LastWAL(s.core)
return resp, nil
}
type nodeHAConnectionInfo struct {
nodeInfo *NodeInformation
lastHeartbeat time.Time
version string
upgradeVersion string
redundancyZone string
}
func (s *forwardedRequestRPCServer) Echo(ctx context.Context, in *EchoRequest) (*EchoReply, error) {
incomingNodeConnectionInfo := nodeHAConnectionInfo{
nodeInfo: in.NodeInfo,
lastHeartbeat: time.Now(),
version: in.SdkVersion,
upgradeVersion: in.RaftUpgradeVersion,
redundancyZone: in.RaftRedundancyZone,
}
if in.ClusterAddr != "" {
s.core.clusterPeerClusterAddrsCache.Set(in.ClusterAddr, incomingNodeConnectionInfo, 0)
}
if in.RaftAppliedIndex > 0 && len(in.RaftNodeID) > 0 && s.raftFollowerStates != nil {
s.raftFollowerStates.Update(&raft.EchoRequestUpdate{
NodeID: in.RaftNodeID,
AppliedIndex: in.RaftAppliedIndex,
Term: in.RaftTerm,
DesiredSuffrage: in.RaftDesiredSuffrage,
SDKVersion: in.SdkVersion,
UpgradeVersion: in.RaftUpgradeVersion,
RedundancyZone: in.RaftRedundancyZone,
})
}
reply := &EchoReply{
Message: "pong",
ReplicationState: uint32(s.core.ReplicationState()),
}
if raftBackend := s.core.getRaftBackend(); raftBackend != nil {
reply.RaftAppliedIndex = raftBackend.AppliedIndex()
reply.RaftNodeID = raftBackend.NodeID()
}
return reply, nil
}
type forwardingClient struct {
RequestForwardingClient
core *Core
echoTicker *time.Ticker
echoContext context.Context
}
// NOTE: we also take advantage of gRPC's keepalive bits, but as we send data
// with these requests it's useful to keep this as well
func (c *forwardingClient) startHeartbeat() {
go func() {
clusterAddr := c.core.ClusterAddr()
hostname, _ := os.Hostname()
ni := NodeInformation{
ApiAddr: c.core.redirectAddr,
Hostname: hostname,
Mode: "standby",
}
tick := func() {
labels := make([]metrics.Label, 0, 1)
defer metrics.MeasureSinceWithLabels([]string{"ha", "rpc", "client", "echo"}, time.Now(), labels)
req := &EchoRequest{
Message: "ping",
ClusterAddr: clusterAddr,
NodeInfo: &ni,
SdkVersion: c.core.effectiveSDKVersion,
}
if raftBackend := c.core.getRaftBackend(); raftBackend != nil {
req.RaftAppliedIndex = raftBackend.AppliedIndex()
req.RaftNodeID = raftBackend.NodeID()
req.RaftTerm = raftBackend.Term()
req.RaftDesiredSuffrage = raftBackend.DesiredSuffrage()
req.RaftRedundancyZone = raftBackend.RedundancyZone()
req.RaftUpgradeVersion = raftBackend.EffectiveVersion()
labels = append(labels, metrics.Label{Name: "peer_id", Value: raftBackend.NodeID()})
}
ctx, cancel := context.WithTimeout(c.echoContext, 2*time.Second)
resp, err := c.RequestForwardingClient.Echo(ctx, req)
cancel()
if err != nil {
metrics.IncrCounter([]string{"ha", "rpc", "client", "echo", "errors"}, 1)
c.core.logger.Debug("forwarding: error sending echo request to active node", "error", err)
return
}
if resp == nil {
c.core.logger.Debug("forwarding: empty echo response from active node")
return
}
if resp.Message != "pong" {
c.core.logger.Debug("forwarding: unexpected echo response from active node", "message", resp.Message)
return
}
// Store the active node's replication state to display in
// sys/health calls
atomic.StoreUint32(c.core.activeNodeReplicationState, resp.ReplicationState)
}
tick()
for {
select {
case <-c.echoContext.Done():
c.echoTicker.Stop()
c.core.logger.Debug("forwarding: stopping heartbeating")
atomic.StoreUint32(c.core.activeNodeReplicationState, uint32(consts.ReplicationUnknown))
return
case <-c.echoTicker.C:
tick()
}
}
}()
}