parent
122f79b3c1
commit
37320f8798
|
@ -2,6 +2,11 @@
|
|||
|
||||
DEPRECATIONS/BREAKING CHANGES:
|
||||
|
||||
* Once the active node is 0.6.1, standby nodes must also be 0.6.1 in order to
|
||||
connect to the HA cluster. We recommend following our [general upgrade
|
||||
instructions](https://www.vaultproject.io/docs/install/upgrade.html) in
|
||||
addition to 0.6.1-specific upgrade instructions to ensure that this is not
|
||||
an issue.
|
||||
* Root tokens (tokens with the `root` policy) can no longer be created except
|
||||
by another root token or the `generate-root` endpoint.
|
||||
* Issued certificates from the `pki` backend against new roles created or
|
||||
|
@ -104,7 +109,7 @@ IMPROVEMENTS:
|
|||
[GH-1699]
|
||||
* physical/etcd: Support `ETCD_ADDR` env var for specifying addresses [GH-1576]
|
||||
* physical/consul: Allowing additional tags to be added to Consul service
|
||||
registration via `service-tags` option [GH-1643]
|
||||
registration via `service_tags` option [GH-1643]
|
||||
* secret/aws: Listing of roles is supported now [GH-1546]
|
||||
* secret/cassandra: Add `connect_timeout` value for Cassandra connection
|
||||
configuration [GH-1581]
|
||||
|
|
|
@ -54,7 +54,7 @@ type ServerCommand struct {
|
|||
}
|
||||
|
||||
func (c *ServerCommand) Run(args []string) int {
|
||||
var dev, verifyOnly bool
|
||||
var dev, verifyOnly, devHA bool
|
||||
var configPath []string
|
||||
var logLevel, devRootTokenID, devListenAddress string
|
||||
flags := c.Meta.FlagSet("server", meta.FlagSetDefault)
|
||||
|
@ -63,6 +63,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
flags.StringVar(&devListenAddress, "dev-listen-address", "", "")
|
||||
flags.StringVar(&logLevel, "log-level", "info", "")
|
||||
flags.BoolVar(&verifyOnly, "verify-only", false, "")
|
||||
flags.BoolVar(&devHA, "dev-ha", false, "")
|
||||
flags.Usage = func() { c.Ui.Error(c.Help()) }
|
||||
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
|
||||
if err := flags.Parse(args); err != nil {
|
||||
|
@ -98,7 +99,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
// Load the configuration
|
||||
var config *server.Config
|
||||
if dev {
|
||||
config = server.DevConfig()
|
||||
config = server.DevConfig(devHA)
|
||||
if devListenAddress != "" {
|
||||
config.Listeners[0].Config["address"] = devListenAddress
|
||||
}
|
||||
|
@ -179,7 +180,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
Physical: backend,
|
||||
AdvertiseAddr: config.Backend.AdvertiseAddr,
|
||||
RedirectAddr: config.Backend.RedirectAddr,
|
||||
HAPhysical: nil,
|
||||
Seal: seal,
|
||||
AuditBackends: c.AuditBackends,
|
||||
|
@ -193,6 +194,8 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
ClusterName: config.ClusterName,
|
||||
}
|
||||
|
||||
var disableClustering bool
|
||||
|
||||
// Initialize the separate HA physical backend, if it exists
|
||||
var ok bool
|
||||
if config.HABackend != nil {
|
||||
|
@ -215,35 +218,85 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
coreConfig.AdvertiseAddr = config.HABackend.AdvertiseAddr
|
||||
coreConfig.RedirectAddr = config.HABackend.RedirectAddr
|
||||
disableClustering = config.HABackend.DisableClustering
|
||||
if !disableClustering {
|
||||
coreConfig.ClusterAddr = config.HABackend.ClusterAddr
|
||||
}
|
||||
} else {
|
||||
if coreConfig.HAPhysical, ok = backend.(physical.HABackend); ok {
|
||||
coreConfig.AdvertiseAddr = config.Backend.AdvertiseAddr
|
||||
coreConfig.RedirectAddr = config.Backend.RedirectAddr
|
||||
disableClustering = config.Backend.DisableClustering
|
||||
if !disableClustering {
|
||||
coreConfig.ClusterAddr = config.Backend.ClusterAddr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if envAA := os.Getenv("VAULT_ADVERTISE_ADDR"); envAA != "" {
|
||||
coreConfig.AdvertiseAddr = envAA
|
||||
if envRA := os.Getenv("VAULT_REDIRECT_ADDR"); envRA != "" {
|
||||
coreConfig.RedirectAddr = envRA
|
||||
} else if envAA := os.Getenv("VAULT_ADVERTISE_ADDR"); envAA != "" {
|
||||
coreConfig.RedirectAddr = envAA
|
||||
}
|
||||
|
||||
// Attempt to detect the advertise address, if possible
|
||||
var detect physical.AdvertiseDetect
|
||||
// Attempt to detect the redirect address, if possible
|
||||
var detect physical.RedirectDetect
|
||||
if coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() {
|
||||
detect, ok = coreConfig.HAPhysical.(physical.AdvertiseDetect)
|
||||
detect, ok = coreConfig.HAPhysical.(physical.RedirectDetect)
|
||||
} else {
|
||||
detect, ok = coreConfig.Physical.(physical.AdvertiseDetect)
|
||||
detect, ok = coreConfig.Physical.(physical.RedirectDetect)
|
||||
}
|
||||
if ok && coreConfig.AdvertiseAddr == "" {
|
||||
advertise, err := c.detectAdvertise(detect, config)
|
||||
if ok && coreConfig.RedirectAddr == "" {
|
||||
redirect, err := c.detectRedirect(detect, config)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error detecting advertise address: %s", err))
|
||||
} else if advertise == "" {
|
||||
c.Ui.Error("Failed to detect advertise address.")
|
||||
c.Ui.Error(fmt.Sprintf("Error detecting redirect address: %s", err))
|
||||
} else if redirect == "" {
|
||||
c.Ui.Error("Failed to detect redirect address.")
|
||||
} else {
|
||||
coreConfig.AdvertiseAddr = advertise
|
||||
coreConfig.RedirectAddr = redirect
|
||||
}
|
||||
}
|
||||
|
||||
// After the redirect bits are sorted out, if no cluster address was
|
||||
// explicitly given, derive one from the redirect addr
|
||||
if disableClustering {
|
||||
coreConfig.ClusterAddr = ""
|
||||
} else if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
|
||||
coreConfig.ClusterAddr = envCA
|
||||
} else if coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "" {
|
||||
u, err := url.ParseRequestURI(coreConfig.RedirectAddr)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error parsing redirect address %s: %v", coreConfig.RedirectAddr, err))
|
||||
return 1
|
||||
}
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
nPort, nPortErr := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
// assume it's due to there not being a port specified, in which case
|
||||
// use 443
|
||||
host = u.Host
|
||||
nPort = 443
|
||||
}
|
||||
if nPortErr != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Cannot parse %s as a numeric port: %v", port, nPortErr))
|
||||
return 1
|
||||
}
|
||||
u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1))
|
||||
// Will always be TLS-secured
|
||||
u.Scheme = "https"
|
||||
coreConfig.ClusterAddr = u.String()
|
||||
}
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
// Force https as we'll always be TLS-secured
|
||||
u, err := url.ParseRequestURI(coreConfig.ClusterAddr)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err))
|
||||
return 1
|
||||
}
|
||||
u.Scheme = "https"
|
||||
coreConfig.ClusterAddr = u.String()
|
||||
}
|
||||
|
||||
// Initialize the core
|
||||
core, newCoreError := vault.NewCore(coreConfig)
|
||||
if newCoreError != nil {
|
||||
|
@ -253,39 +306,6 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
}
|
||||
}
|
||||
|
||||
// If we're in dev mode, then initialize the core
|
||||
if dev {
|
||||
init, err := c.enableDev(core, devRootTokenID)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing dev mode: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
export := "export"
|
||||
quote := "'"
|
||||
if runtime.GOOS == "windows" {
|
||||
export = "set"
|
||||
quote = ""
|
||||
}
|
||||
|
||||
c.Ui.Output(fmt.Sprintf(
|
||||
"==> WARNING: Dev mode is enabled!\n\n"+
|
||||
"In this mode, Vault is completely in-memory and unsealed.\n"+
|
||||
"Vault is configured to only have a single unseal key. The root\n"+
|
||||
"token has already been authenticated with the CLI, so you can\n"+
|
||||
"immediately begin using the Vault CLI.\n\n"+
|
||||
"The only step you need to take is to set the following\n"+
|
||||
"environment variables:\n\n"+
|
||||
" "+export+" VAULT_ADDR="+quote+"http://"+config.Listeners[0].Config["address"]+quote+"\n\n"+
|
||||
"The unseal key and root token are reproduced below in case you\n"+
|
||||
"want to seal/unseal the Vault or play with authentication.\n\n"+
|
||||
"Unseal Key: %s\nRoot Token: %s\n",
|
||||
hex.EncodeToString(init.SecretShares[0]),
|
||||
init.RootToken,
|
||||
))
|
||||
}
|
||||
|
||||
// Compile server information for output later
|
||||
info["backend"] = config.Backend.Type
|
||||
info["log level"] = logLevel
|
||||
|
@ -296,21 +316,31 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
if config.HABackend != nil {
|
||||
info["HA backend"] = config.HABackend.Type
|
||||
info["advertise address"] = coreConfig.AdvertiseAddr
|
||||
infoKeys = append(infoKeys, "HA backend", "advertise address")
|
||||
info["redirect address"] = coreConfig.RedirectAddr
|
||||
infoKeys = append(infoKeys, "HA backend", "redirect address")
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
info["cluster address"] = coreConfig.ClusterAddr
|
||||
infoKeys = append(infoKeys, "cluster address")
|
||||
}
|
||||
} else {
|
||||
// If the backend supports HA, then note it
|
||||
if coreConfig.HAPhysical != nil {
|
||||
if coreConfig.HAPhysical.HAEnabled() {
|
||||
info["backend"] += " (HA available)"
|
||||
info["advertise address"] = coreConfig.AdvertiseAddr
|
||||
infoKeys = append(infoKeys, "advertise address")
|
||||
info["redirect address"] = coreConfig.RedirectAddr
|
||||
infoKeys = append(infoKeys, "redirect address")
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
info["cluster address"] = coreConfig.ClusterAddr
|
||||
infoKeys = append(infoKeys, "cluster address")
|
||||
}
|
||||
} else {
|
||||
info["backend"] += " (HA disabled)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusterAddrs := []string{}
|
||||
|
||||
// Initialize the listeners
|
||||
lns := make([]net.Listener, 0, len(config.Listeners))
|
||||
for i, lnConfig := range config.Listeners {
|
||||
|
@ -322,6 +352,35 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
lns = append(lns, ln)
|
||||
|
||||
if reloadFunc != nil {
|
||||
relSlice := c.ReloadFuncs["listener|"+lnConfig.Type]
|
||||
relSlice = append(relSlice, reloadFunc)
|
||||
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
|
||||
}
|
||||
|
||||
if !disableClustering && lnConfig.Type == "tcp" {
|
||||
var addr string
|
||||
var ok bool
|
||||
if addr, ok = lnConfig.Config["cluster_address"]; ok {
|
||||
clusterAddrs = append(clusterAddrs, addr)
|
||||
} else {
|
||||
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
c.Ui.Error("Failed to parse tcp listener")
|
||||
return 1
|
||||
}
|
||||
ipStr := tcpAddr.IP.String()
|
||||
if len(tcpAddr.IP) == net.IPv6len {
|
||||
ipStr = fmt.Sprintf("[%s]", ipStr)
|
||||
}
|
||||
addr = fmt.Sprintf("%s:%d", ipStr, tcpAddr.Port+1)
|
||||
clusterAddrs = append(clusterAddrs, addr)
|
||||
}
|
||||
props["cluster address"] = addr
|
||||
}
|
||||
|
||||
// Store the listener props for output later
|
||||
key := fmt.Sprintf("listener %d", i+1)
|
||||
propsList := make([]string, 0, len(props))
|
||||
|
@ -334,13 +393,9 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
info[key] = fmt.Sprintf(
|
||||
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))
|
||||
|
||||
lns = append(lns, ln)
|
||||
|
||||
if reloadFunc != nil {
|
||||
relSlice := c.ReloadFuncs["listener|"+lnConfig.Type]
|
||||
relSlice = append(relSlice, reloadFunc)
|
||||
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
|
||||
}
|
||||
}
|
||||
if !disableClustering {
|
||||
c.logger.Printf("[TRACE] cluster listeners will be started on %v", clusterAddrs)
|
||||
}
|
||||
|
||||
// Make sure we close all listeners from this point on
|
||||
|
@ -394,16 +449,55 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return true
|
||||
}
|
||||
|
||||
if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil {
|
||||
if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handler := vaulthttp.Handler(core)
|
||||
|
||||
// This needs to happen before we first unseal, so before we trigger dev
|
||||
// mode if it's set
|
||||
core.SetClusterListenerSetupFunc(vault.WrapListenersForClustering(clusterAddrs, handler, c.logger))
|
||||
|
||||
// If we're in dev mode, then initialize the core
|
||||
if dev {
|
||||
init, err := c.enableDev(core, devRootTokenID)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Error initializing dev mode: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
export := "export"
|
||||
quote := "'"
|
||||
if runtime.GOOS == "windows" {
|
||||
export = "set"
|
||||
quote = ""
|
||||
}
|
||||
|
||||
c.Ui.Output(fmt.Sprintf(
|
||||
"==> WARNING: Dev mode is enabled!\n\n"+
|
||||
"In this mode, Vault is completely in-memory and unsealed.\n"+
|
||||
"Vault is configured to only have a single unseal key. The root\n"+
|
||||
"token has already been authenticated with the CLI, so you can\n"+
|
||||
"immediately begin using the Vault CLI.\n\n"+
|
||||
"The only step you need to take is to set the following\n"+
|
||||
"environment variables:\n\n"+
|
||||
" "+export+" VAULT_ADDR="+quote+"http://"+config.Listeners[0].Config["address"]+quote+"\n\n"+
|
||||
"The unseal key and root token are reproduced below in case you\n"+
|
||||
"want to seal/unseal the Vault or play with authentication.\n\n"+
|
||||
"Unseal Key: %s\nRoot Token: %s\n",
|
||||
hex.EncodeToString(init.SecretShares[0]),
|
||||
init.RootToken,
|
||||
))
|
||||
}
|
||||
|
||||
// Initialize the HTTP server
|
||||
server := &http.Server{}
|
||||
server.Handler = vaulthttp.Handler(core)
|
||||
server.Handler = handler
|
||||
for _, ln := range lns {
|
||||
go server.Serve(ln)
|
||||
}
|
||||
|
@ -466,6 +560,27 @@ func (c *ServerCommand) enableDev(core *vault.Core, rootTokenID string) (*vault.
|
|||
return nil, fmt.Errorf("failed to unseal Vault for dev mode")
|
||||
}
|
||||
|
||||
isLeader, _, err := core.Leader()
|
||||
if err != nil && err != vault.ErrHANotEnabled {
|
||||
return nil, fmt.Errorf("failed to check active status: %v", err)
|
||||
}
|
||||
if err == nil {
|
||||
leaderCount := 5
|
||||
for !isLeader {
|
||||
if leaderCount == 0 {
|
||||
buf := make([]byte, 1<<16)
|
||||
runtime.Stack(buf, true)
|
||||
return nil, fmt.Errorf("failed to get active status after five seconds; call stack is\n%s\n", buf)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
isLeader, _, err = core.Leader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check active status: %v", err)
|
||||
}
|
||||
leaderCount--
|
||||
}
|
||||
}
|
||||
|
||||
if rootTokenID != "" {
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
|
@ -511,8 +626,8 @@ func (c *ServerCommand) enableDev(core *vault.Core, rootTokenID string) (*vault.
|
|||
return init, nil
|
||||
}
|
||||
|
||||
// detectAdvertise is used to attempt advertise address detection
|
||||
func (c *ServerCommand) detectAdvertise(detect physical.AdvertiseDetect,
|
||||
// detectRedirect is used to attempt redirect address detection
|
||||
func (c *ServerCommand) detectRedirect(detect physical.RedirectDetect,
|
||||
config *server.Config) (string, error) {
|
||||
// Get the hostname
|
||||
host, err := detect.DetectHostAddr()
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -38,8 +39,8 @@ type Config struct {
|
|||
}
|
||||
|
||||
// DevConfig is a Config that is used for dev mode of Vault.
|
||||
func DevConfig() *Config {
|
||||
return &Config{
|
||||
func DevConfig(ha bool) *Config {
|
||||
ret := &Config{
|
||||
DisableCache: false,
|
||||
DisableMlock: true,
|
||||
|
||||
|
@ -62,6 +63,12 @@ func DevConfig() *Config {
|
|||
MaxLeaseTTL: 30 * 24 * time.Hour,
|
||||
DefaultLeaseTTL: 30 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
if ha {
|
||||
ret.Backend.Type = "inmem_ha"
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Listener is the listener configuration for the server.
|
||||
|
@ -76,9 +83,11 @@ func (l *Listener) GoString() string {
|
|||
|
||||
// Backend is the backend configuration for the server.
|
||||
type Backend struct {
|
||||
Type string
|
||||
AdvertiseAddr string
|
||||
Config map[string]string
|
||||
Type string
|
||||
RedirectAddr string
|
||||
ClusterAddr string
|
||||
DisableClustering bool
|
||||
Config map[string]string
|
||||
}
|
||||
|
||||
func (b *Backend) GoString() string {
|
||||
|
@ -442,17 +451,40 @@ func parseBackends(result *Config, list *ast.ObjectList) error {
|
|||
return multierror.Prefix(err, fmt.Sprintf("backend.%s:", key))
|
||||
}
|
||||
|
||||
// Pull out the advertise address since it's common to all backends
|
||||
var advertiseAddr string
|
||||
if v, ok := m["advertise_addr"]; ok {
|
||||
advertiseAddr = v
|
||||
// Pull out the redirect address since it's common to all backends
|
||||
var redirectAddr string
|
||||
if v, ok := m["redirect_addr"]; ok {
|
||||
redirectAddr = v
|
||||
delete(m, "redirect_addr")
|
||||
} else if v, ok := m["advertise_addr"]; ok {
|
||||
redirectAddr = v
|
||||
delete(m, "advertise_addr")
|
||||
}
|
||||
|
||||
// Pull out the cluster address since it's common to all backends
|
||||
var clusterAddr string
|
||||
if v, ok := m["cluster_addr"]; ok {
|
||||
clusterAddr = v
|
||||
delete(m, "cluster_addr")
|
||||
}
|
||||
|
||||
//TODO: Change this in the future
|
||||
disableClustering := true
|
||||
var err error
|
||||
if v, ok := m["disable_clustering"]; ok {
|
||||
disableClustering, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return multierror.Prefix(err, fmt.Sprintf("backend.%s:", key))
|
||||
}
|
||||
delete(m, "disable_clustering")
|
||||
}
|
||||
|
||||
result.Backend = &Backend{
|
||||
AdvertiseAddr: advertiseAddr,
|
||||
Type: strings.ToLower(key),
|
||||
Config: m,
|
||||
RedirectAddr: redirectAddr,
|
||||
ClusterAddr: clusterAddr,
|
||||
DisableClustering: disableClustering,
|
||||
Type: strings.ToLower(key),
|
||||
Config: m,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -475,17 +507,40 @@ func parseHABackends(result *Config, list *ast.ObjectList) error {
|
|||
return multierror.Prefix(err, fmt.Sprintf("ha_backend.%s:", key))
|
||||
}
|
||||
|
||||
// Pull out the advertise address since it's common to all backends
|
||||
var advertiseAddr string
|
||||
if v, ok := m["advertise_addr"]; ok {
|
||||
advertiseAddr = v
|
||||
// Pull out the redirect address since it's common to all backends
|
||||
var redirectAddr string
|
||||
if v, ok := m["redirect_addr"]; ok {
|
||||
redirectAddr = v
|
||||
delete(m, "redirect_addr")
|
||||
} else if v, ok := m["advertise_addr"]; ok {
|
||||
redirectAddr = v
|
||||
delete(m, "advertise_addr")
|
||||
}
|
||||
|
||||
// Pull out the cluster address since it's common to all backends
|
||||
var clusterAddr string
|
||||
if v, ok := m["cluster_addr"]; ok {
|
||||
clusterAddr = v
|
||||
delete(m, "cluster_addr")
|
||||
}
|
||||
|
||||
//TODO: Change this in the future
|
||||
disableClustering := true
|
||||
var err error
|
||||
if v, ok := m["disable_clustering"]; ok {
|
||||
disableClustering, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return multierror.Prefix(err, fmt.Sprintf("backend.%s:", key))
|
||||
}
|
||||
delete(m, "disable_clustering")
|
||||
}
|
||||
|
||||
result.HABackend = &Backend{
|
||||
AdvertiseAddr: advertiseAddr,
|
||||
Type: strings.ToLower(key),
|
||||
Config: m,
|
||||
RedirectAddr: redirectAddr,
|
||||
ClusterAddr: clusterAddr,
|
||||
DisableClustering: disableClustering,
|
||||
Type: strings.ToLower(key),
|
||||
Config: m,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -502,6 +557,7 @@ func parseListeners(result *Config, list *ast.ObjectList) error {
|
|||
|
||||
valid := []string{
|
||||
"address",
|
||||
"cluster_address",
|
||||
"endpoint",
|
||||
"infrastructure",
|
||||
"node_id",
|
||||
|
|
|
@ -33,16 +33,17 @@ func TestLoadConfigFile(t *testing.T) {
|
|||
},
|
||||
|
||||
Backend: &Backend{
|
||||
Type: "consul",
|
||||
AdvertiseAddr: "foo",
|
||||
Type: "consul",
|
||||
RedirectAddr: "foo",
|
||||
Config: map[string]string{
|
||||
"foo": "bar",
|
||||
},
|
||||
DisableClustering: true,
|
||||
},
|
||||
|
||||
HABackend: &Backend{
|
||||
Type: "consul",
|
||||
AdvertiseAddr: "snafu",
|
||||
Type: "consul",
|
||||
RedirectAddr: "snafu",
|
||||
Config: map[string]string{
|
||||
"bar": "baz",
|
||||
},
|
||||
|
@ -98,6 +99,7 @@ func TestLoadConfigFile_json(t *testing.T) {
|
|||
Config: map[string]string{
|
||||
"foo": "bar",
|
||||
},
|
||||
DisableClustering: true,
|
||||
},
|
||||
|
||||
Telemetry: &Telemetry{
|
||||
|
@ -155,6 +157,7 @@ func TestLoadConfigFile_json2(t *testing.T) {
|
|||
Config: map[string]string{
|
||||
"foo": "bar",
|
||||
},
|
||||
DisableClustering: true,
|
||||
},
|
||||
|
||||
HABackend: &Backend{
|
||||
|
@ -182,7 +185,6 @@ func TestLoadConfigFile_json2(t *testing.T) {
|
|||
},
|
||||
}
|
||||
if !reflect.DeepEqual(config, expected) {
|
||||
t.Fatalf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config, expected)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -223,7 +225,7 @@ func TestLoadConfigDir(t *testing.T) {
|
|||
ClusterName: "testcluster",
|
||||
}
|
||||
if !reflect.DeepEqual(config, expected) {
|
||||
t.Fatalf("bad: %#v", config)
|
||||
t.Fatalf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config, expected)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,4 +3,5 @@ disable_mlock = true
|
|||
|
||||
backend "consul" {
|
||||
foo = "bar"
|
||||
disable_clustering = "false"
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ backend "consul" {
|
|||
ha_backend "consul" {
|
||||
bar = "baz"
|
||||
advertise_addr = "snafu"
|
||||
disable_clustering = "false"
|
||||
}
|
||||
|
||||
max_lease_ttl = "10h"
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
},
|
||||
"ha_backend":{
|
||||
"consul":{
|
||||
"bar":"baz"
|
||||
"bar":"baz",
|
||||
"disable_clustering": "false"
|
||||
}
|
||||
},
|
||||
"telemetry":{
|
||||
|
|
|
@ -39,7 +39,7 @@ backend "consul" {
|
|||
haconsulhcl = `
|
||||
ha_backend "consul" {
|
||||
prefix = "bar/"
|
||||
advertise_addr = "http://127.0.0.1:8200"
|
||||
redirect_addr = "http://127.0.0.1:8200"
|
||||
disable_registration = "true"
|
||||
}
|
||||
`
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
package requestutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/hashicorp/vault/helper/compressutil"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
)
|
||||
|
||||
type bufCloser struct {
|
||||
*bytes.Buffer
|
||||
}
|
||||
|
||||
func (b bufCloser) Close() error {
|
||||
b.Reset()
|
||||
return nil
|
||||
}
|
||||
|
||||
type ForwardedRequest struct {
|
||||
// The original method
|
||||
Method string `json:"method"`
|
||||
|
||||
// The original URL object
|
||||
URL *url.URL `json:"url"`
|
||||
|
||||
// The original headers
|
||||
Header http.Header `json:"header"`
|
||||
|
||||
// The request body
|
||||
Body []byte `json:"body"`
|
||||
|
||||
// The specified host
|
||||
Host string `json:"host"`
|
||||
|
||||
// The remote address
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
|
||||
// The client's TLS connection state
|
||||
ConnectionState *tls.ConnectionState `json:"connection_state"`
|
||||
}
|
||||
|
||||
// GenerateForwardedRequest generates a new http.Request that contains the
|
||||
// original requests's information in the new request's body.
|
||||
func GenerateForwardedRequest(req *http.Request, addr string) (*http.Request, error) {
|
||||
fq := ForwardedRequest{
|
||||
Method: req.Method,
|
||||
URL: req.URL,
|
||||
Header: req.Header,
|
||||
Host: req.Host,
|
||||
RemoteAddr: req.RemoteAddr,
|
||||
ConnectionState: req.TLS,
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
_, err := buf.ReadFrom(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fq.Body = buf.Bytes()
|
||||
|
||||
newBody, err := jsonutil.EncodeJSONAndCompress(&fq, &compressutil.CompressionConfig{
|
||||
Type: compressutil.CompressionTypeLzw,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret, err := http.NewRequest("POST", addr, bytes.NewBuffer(newBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// ParseForwardedRequest generates a new http.Request that is comprised of the
|
||||
// values in the given request's body, assuming it correctly parses into a
|
||||
// ForwardedRequest.
|
||||
func ParseForwardedRequest(req *http.Request) (*http.Request, error) {
|
||||
buf := bufCloser{
|
||||
Buffer: bytes.NewBuffer(nil),
|
||||
}
|
||||
_, err := buf.ReadFrom(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fq ForwardedRequest
|
||||
err = jsonutil.DecodeJSON(buf.Bytes(), &fq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
_, err = buf.Write(fq.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &http.Request{
|
||||
Method: fq.Method,
|
||||
URL: fq.URL,
|
||||
Header: fq.Header,
|
||||
Body: buf,
|
||||
Host: fq.Host,
|
||||
RemoteAddr: fq.RemoteAddr,
|
||||
TLS: fq.ConnectionState,
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
package requestutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestForwardedRequestGenerateParse(t *testing.T) {
|
||||
bodBuf := bytes.NewReader([]byte(`{ "foo": "bar", "zip": { "argle": "bargle", neet: 0 } }`))
|
||||
req, err := http.NewRequest("FOOBAR", "https://pushit.real.good:9281/snicketysnack?furbleburble=bloopetybloop", bodBuf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.TLS = &tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
HandshakeComplete: true,
|
||||
ServerName: "tralala",
|
||||
}
|
||||
|
||||
// We want to get the fields we would expect from an incoming request, so
|
||||
// we write it out and then read it again
|
||||
buf1 := bytes.NewBuffer(nil)
|
||||
err = req.Write(buf1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read it back in, parsing like a server
|
||||
bufr1 := bufio.NewReader(buf1)
|
||||
initialReq, err := http.ReadRequest(bufr1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate the request with the forwarded request in the body
|
||||
req, err = GenerateForwardedRequest(initialReq, "https://bloopety.bloop:8201")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform another "round trip"
|
||||
buf2 := bytes.NewBuffer(nil)
|
||||
err = req.Write(buf2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bufr2 := bufio.NewReader(buf2)
|
||||
intreq, err := http.ReadRequest(bufr2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Now extract the forwarded request to generate a final request for processing
|
||||
finalReq, err := ParseForwardedRequest(intreq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case initialReq.Method != finalReq.Method:
|
||||
t.Fatalf("bad method:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
|
||||
case initialReq.RemoteAddr != finalReq.RemoteAddr:
|
||||
t.Fatalf("bad remoteaddr:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
|
||||
case initialReq.Host != finalReq.Host:
|
||||
t.Fatalf("bad host:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
|
||||
case !reflect.DeepEqual(initialReq.URL, finalReq.URL):
|
||||
t.Fatalf("bad url:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq.URL, *finalReq.URL)
|
||||
case !reflect.DeepEqual(initialReq.Header, finalReq.Header):
|
||||
t.Fatalf("bad header:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
|
||||
case !reflect.DeepEqual(initialReq.TLS, finalReq.TLS):
|
||||
t.Fatalf("bad tls:\ninitialReq:\n%#v\nfinalReq:\n%#v\n", *initialReq, *finalReq)
|
||||
default:
|
||||
// Compare bodies
|
||||
bodBuf.Seek(0, 0)
|
||||
initBuf := bytes.NewBuffer(nil)
|
||||
_, err = initBuf.ReadFrom(bodBuf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
finBuf := bytes.NewBuffer(nil)
|
||||
_, err = finBuf.ReadFrom(finalReq.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(initBuf.Bytes(), finBuf.Bytes()) {
|
||||
t.Fatalf("badbody :\ninitialReq:\n%#v\nfinalReq:\n%#v\n", initBuf.Bytes(), finBuf.Bytes())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,521 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/api"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestHTTP_Fallback_Bad_Address(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
ClusterAddr: "https://127.3.4.1:8382",
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = cleanhttp.DefaultClient()
|
||||
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.SetToken(root)
|
||||
|
||||
secret, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Data["id"].(string) != root {
|
||||
t.Fatal("token mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP_Fallback_Disabled(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
ClusterAddr: "empty",
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = cleanhttp.DefaultClient()
|
||||
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.SetToken(root)
|
||||
|
||||
secret, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Data["id"].(string) != root {
|
||||
t.Fatal("token mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function recreates the fuzzy testing from transit to pipe a large
|
||||
// number of requests from the standbys to the active node.
|
||||
func TestHTTP_Forwarding_Stress(t *testing.T) {
|
||||
testPlaintext := "the quick brown fox"
|
||||
testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
||||
keys := []string{"test1", "test2", "test3"}
|
||||
|
||||
hosts := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultPooledTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return fmt.Errorf("redirects not allowed in this test")
|
||||
},
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] mounting transit")
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer([]byte("{\"type\": \"transit\"}")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//core.Logger().Printf("[TRACE] done mounting transit")
|
||||
|
||||
var totalOps int64
|
||||
var successfulOps int64
|
||||
var key1ver int64 = 1
|
||||
var key2ver int64 = 1
|
||||
var key3ver int64 = 1
|
||||
|
||||
// This is the goroutine loop
|
||||
doFuzzy := func(id int) {
|
||||
// Check for panics, otherwise notify we're done
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
core.Logger().Printf("[ERR] got a panic: %v", err)
|
||||
t.Fail()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Holds the latest encrypted value for each key
|
||||
latestEncryptedText := map[string]string{}
|
||||
|
||||
startTime := time.Now()
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
var chosenFunc, chosenKey, chosenHost string
|
||||
|
||||
doReq := func(method, url string, body io.Reader) (*http.Response, error) {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
doResp := func(resp *http.Response) (*api.Secret, error) {
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("nil response")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Make sure we weren't redirected
|
||||
if resp.StatusCode > 300 && resp.StatusCode < 400 {
|
||||
return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp)
|
||||
}
|
||||
|
||||
result := &api.Response{Response: resp}
|
||||
err = result.Error()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secret, err := api.ParseSecret(result.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
for _, chosenHost := range hosts {
|
||||
for _, chosenKey := range keys {
|
||||
// Try to write the key to make sure it exists
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey, bytes.NewBuffer([]byte("{}")))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] Starting %d", id)
|
||||
for {
|
||||
// Stop after 10 seconds
|
||||
if time.Now().Sub(startTime) > 10*time.Second {
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt64(&totalOps, 1)
|
||||
|
||||
// Pick a function and a key
|
||||
chosenFunc = funcs[rand.Int()%len(funcs)]
|
||||
chosenKey = keys[rand.Int()%len(keys)]
|
||||
chosenHost = hosts[rand.Int()%len(hosts)]
|
||||
|
||||
switch chosenFunc {
|
||||
// Encrypt our plaintext and store the result
|
||||
case "encrypt":
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
secret, err := doResp(resp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
latest := secret.Data["ciphertext"].(string)
|
||||
if latest == "" {
|
||||
panic(fmt.Errorf("bad ciphertext"))
|
||||
}
|
||||
latestEncryptedText[chosenKey] = secret.Data["ciphertext"].(string)
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Decrypt the ciphertext and compare the result
|
||||
case "decrypt":
|
||||
ct := latestEncryptedText[chosenKey]
|
||||
if ct == "" {
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
secret, err := doResp(resp)
|
||||
if err != nil {
|
||||
// This could well happen since the min version is jumping around
|
||||
if strings.Contains(err.Error(), transit.ErrTooOld) {
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
continue
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptb64 := secret.Data["plaintext"].(string)
|
||||
pt, err := base64.StdEncoding.DecodeString(ptb64)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err))
|
||||
}
|
||||
if string(pt) != testPlaintext {
|
||||
panic(fmt.Errorf("got bad plaintext back: %s", pt))
|
||||
}
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Rotate to a new key version
|
||||
case "rotate":
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}")))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch chosenKey {
|
||||
case "test1":
|
||||
atomic.AddInt64(&key1ver, 1)
|
||||
case "test2":
|
||||
atomic.AddInt64(&key2ver, 1)
|
||||
case "test3":
|
||||
atomic.AddInt64(&key3ver, 1)
|
||||
}
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Change the min version, which also tests the archive functionality
|
||||
case "change_min_version":
|
||||
var latestVersion int64
|
||||
switch chosenKey {
|
||||
case "test1":
|
||||
latestVersion = atomic.LoadInt64(&key1ver)
|
||||
case "test2":
|
||||
latestVersion = atomic.LoadInt64(&key2ver)
|
||||
case "test3":
|
||||
latestVersion = atomic.LoadInt64(&key3ver)
|
||||
}
|
||||
|
||||
setVersion := (rand.Int63() % latestVersion) + 1
|
||||
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d, new min version %d", chosenFunc, chosenKey, id, setVersion)
|
||||
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn 20 of these workers for 10 seconds
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
//core.Logger().Printf("[TRACE] spawning %d", i)
|
||||
go doFuzzy(i)
|
||||
}
|
||||
|
||||
// Wait for them all to finish
|
||||
wg.Wait()
|
||||
|
||||
core.Logger().Printf("[TRACE] total operations tried: %d, total successful: %d", totalOps, successfulOps)
|
||||
if totalOps != successfulOps {
|
||||
t.Fatalf("total/successful ops mismatch: %d/%d", totalOps, successfulOps)
|
||||
}
|
||||
}
|
||||
|
||||
// This tests TLS connection state forwarding by ensuring that we can use a
|
||||
// client TLS to authenticate against the cert backend
|
||||
func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"cert": credCert.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
transport := cleanhttp.DefaultTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/auth/cert", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer([]byte("{\"type\": \"cert\"}")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type certConfig struct {
|
||||
Certificate string `json:"certificate"`
|
||||
Policies string `json:"policies"`
|
||||
}
|
||||
encodedCertConfig, err := json.Marshal(&certConfig{
|
||||
Certificate: vault.TestClusterCACert,
|
||||
Policies: "default",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req, err = http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/auth/cert/certs/test", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer(encodedCertConfig))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
// Ensure we can't possibly use lingering connections even though it should be to a different address
|
||||
|
||||
transport = cleanhttp.DefaultTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client = &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return fmt.Errorf("redirects not allowed in this test")
|
||||
},
|
||||
}
|
||||
|
||||
//cores[0].Logger().Printf("root token is %s", root)
|
||||
//time.Sleep(4 * time.Hour)
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secret, err := client.Logical().Write("auth/cert/login", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Auth == nil {
|
||||
t.Fatal("auth is nil")
|
||||
}
|
||||
if secret.Auth.Policies == nil || len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" {
|
||||
t.Fatalf("bad policies: %#v", secret.Auth.Policies)
|
||||
}
|
||||
}
|
||||
}
|
120
http/handler.go
120
http/handler.go
|
@ -1,6 +1,7 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -22,6 +23,10 @@ const (
|
|||
// WrapHeaderName is the name of the header containing a directive to wrap the
|
||||
// response.
|
||||
WrapTTLHeaderName = "X-Vault-Wrap-TTL"
|
||||
|
||||
// NoRequestForwardingHeaderName is the name of the header telling Vault
|
||||
// not to use request forwarding
|
||||
NoRequestForwardingHeaderName = "X-Vault-No-Request-Forwarding"
|
||||
)
|
||||
|
||||
// Handler returns an http.Handler for the API. This can be used on
|
||||
|
@ -34,19 +39,19 @@ func Handler(core *vault.Core) http.Handler {
|
|||
mux.Handle("/v1/sys/seal", handleSysSeal(core))
|
||||
mux.Handle("/v1/sys/step-down", handleSysStepDown(core))
|
||||
mux.Handle("/v1/sys/unseal", handleSysUnseal(core))
|
||||
mux.Handle("/v1/sys/renew", handleLogical(core, false, nil))
|
||||
mux.Handle("/v1/sys/renew/", handleLogical(core, false, nil))
|
||||
mux.Handle("/v1/sys/renew", handleRequestForwarding(core, handleLogical(core, false, nil)))
|
||||
mux.Handle("/v1/sys/renew/", handleRequestForwarding(core, handleLogical(core, false, nil)))
|
||||
mux.Handle("/v1/sys/leader", handleSysLeader(core))
|
||||
mux.Handle("/v1/sys/health", handleSysHealth(core))
|
||||
mux.Handle("/v1/sys/generate-root/attempt", handleSysGenerateRootAttempt(core))
|
||||
mux.Handle("/v1/sys/generate-root/update", handleSysGenerateRootUpdate(core))
|
||||
mux.Handle("/v1/sys/rekey/init", handleSysRekeyInit(core, false))
|
||||
mux.Handle("/v1/sys/rekey/update", handleSysRekeyUpdate(core, false))
|
||||
mux.Handle("/v1/sys/rekey-recovery-key/init", handleSysRekeyInit(core, true))
|
||||
mux.Handle("/v1/sys/rekey-recovery-key/update", handleSysRekeyUpdate(core, true))
|
||||
mux.Handle("/v1/sys/capabilities-self", handleLogical(core, true, sysCapabilitiesSelfCallback))
|
||||
mux.Handle("/v1/sys/", handleLogical(core, true, nil))
|
||||
mux.Handle("/v1/", handleLogical(core, false, nil))
|
||||
mux.Handle("/v1/sys/generate-root/attempt", handleRequestForwarding(core, handleSysGenerateRootAttempt(core)))
|
||||
mux.Handle("/v1/sys/generate-root/update", handleRequestForwarding(core, handleSysGenerateRootUpdate(core)))
|
||||
mux.Handle("/v1/sys/rekey/init", handleRequestForwarding(core, handleSysRekeyInit(core, false)))
|
||||
mux.Handle("/v1/sys/rekey/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, false)))
|
||||
mux.Handle("/v1/sys/rekey-recovery-key/init", handleRequestForwarding(core, handleSysRekeyInit(core, true)))
|
||||
mux.Handle("/v1/sys/rekey-recovery-key/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, true)))
|
||||
mux.Handle("/v1/sys/capabilities-self", handleRequestForwarding(core, handleLogical(core, true, sysCapabilitiesSelfCallback)))
|
||||
mux.Handle("/v1/sys/", handleRequestForwarding(core, handleLogical(core, true, nil)))
|
||||
mux.Handle("/v1/", handleRequestForwarding(core, handleLogical(core, false, nil)))
|
||||
|
||||
// Wrap the handler in another handler to trigger all help paths.
|
||||
handler := handleHelpHandler(mux, core)
|
||||
|
@ -89,6 +94,79 @@ func parseRequest(r *http.Request, out interface{}) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// handleRequestForwarding determines whether to forward a request or not,
|
||||
// falling back on the older behavior of redirecting the client
|
||||
func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get(vault.IntNoForwardingHeaderName) != "" {
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get(NoRequestForwardingHeaderName) != "" {
|
||||
// Forwarding explicitly disabled, fall back to previous behavior
|
||||
core.Logger().Printf("[TRACE] http/handleRequestForwarding: forwarding disabled by client request")
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Note: in an HA setup, this call will also ensure that connections to
|
||||
// the leader are set up, as that happens once the advertised cluster
|
||||
// values are read during this function
|
||||
isLeader, leaderAddr, err := core.Leader()
|
||||
if err != nil {
|
||||
if err == vault.ErrHANotEnabled {
|
||||
// Standalone node, serve request normally
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Some internal error occurred
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
if isLeader {
|
||||
// No forwarding needed, we're leader
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if leaderAddr == "" {
|
||||
respondError(w, http.StatusInternalServerError, fmt.Errorf("node not active but active node not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt forwarding the request. If we cannot forward -- perhaps it's
|
||||
// been disabled on the active node -- this will return with an
|
||||
// ErrCannotForward and we simply fall back
|
||||
resp, err := core.ForwardRequest(r)
|
||||
if err != nil {
|
||||
if err == vault.ErrCannotForward {
|
||||
core.Logger().Printf("[TRACE] http/handleRequestForwarding: cannot forward (possibly disabled on active node), falling back")
|
||||
} else {
|
||||
core.Logger().Printf("[ERR] http/handleRequestForwarding: error forwarding request: %v", err)
|
||||
}
|
||||
|
||||
// Fall back to redirection
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read the body into a buffer so we can write it back out to the
|
||||
// original requestor
|
||||
buf := bytes.NewBuffer(nil)
|
||||
_, err = buf.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
core.Logger().Printf("[ERR] http/handleRequestForwarding: error reading response body: %v", err)
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
w.Write(buf.Bytes())
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// request is a helper to perform a request and properly exit in the
|
||||
// case of an error.
|
||||
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool) {
|
||||
|
@ -107,43 +185,43 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
|
|||
// respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby
|
||||
func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
|
||||
// Request the leader address
|
||||
_, advertise, err := core.Leader()
|
||||
_, redirectAddr, err := core.Leader()
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is no leader, generate a 503 error
|
||||
if advertise == "" {
|
||||
if redirectAddr == "" {
|
||||
err = fmt.Errorf("no active Vault instance found")
|
||||
respondError(w, http.StatusServiceUnavailable, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the advertise location
|
||||
advertiseURL, err := url.Parse(advertise)
|
||||
// Parse the redirect location
|
||||
redirectURL, err := url.Parse(redirectAddr)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a redirect URL
|
||||
redirectURL := url.URL{
|
||||
Scheme: advertiseURL.Scheme,
|
||||
Host: advertiseURL.Host,
|
||||
finalURL := url.URL{
|
||||
Scheme: redirectURL.Scheme,
|
||||
Host: redirectURL.Host,
|
||||
Path: reqURL.Path,
|
||||
RawQuery: reqURL.RawQuery,
|
||||
}
|
||||
|
||||
// Ensure there is a scheme, default to https
|
||||
if redirectURL.Scheme == "" {
|
||||
redirectURL.Scheme = "https"
|
||||
if finalURL.Scheme == "" {
|
||||
finalURL.Scheme = "https"
|
||||
}
|
||||
|
||||
// If we have an address, redirect! We use a 307 code
|
||||
// because we don't actually know if its permanent and
|
||||
// the request method should be preserved.
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
w.Header().Set("Location", finalURL.String())
|
||||
w.WriteHeader(307)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
func handleHelpHandler(h http.Handler, core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
// If the help parameter is not blank, then show the help
|
||||
if v := req.URL.Query().Get("help"); v != "" {
|
||||
if v := req.URL.Query().Get("help"); v != "" || req.Method == "HELP" {
|
||||
handleHelp(core, w, req)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -84,10 +84,10 @@ func TestLogical_StandbyRedirect(t *testing.T) {
|
|||
// Create an HA Vault
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
conf := &vault.CoreConfig{
|
||||
Physical: inmha,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: addr1,
|
||||
DisableMlock: true,
|
||||
Physical: inmha,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: addr1,
|
||||
DisableMlock: true,
|
||||
}
|
||||
core1, err := vault.NewCore(conf)
|
||||
if err != nil {
|
||||
|
@ -104,10 +104,10 @@ func TestLogical_StandbyRedirect(t *testing.T) {
|
|||
|
||||
// Create a second HA Vault
|
||||
conf2 := &vault.CoreConfig{
|
||||
Physical: inmha,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: addr2,
|
||||
DisableMlock: true,
|
||||
Physical: inmha,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: addr2,
|
||||
DisableMlock: true,
|
||||
}
|
||||
core2, err := vault.NewCore(conf2)
|
||||
if err != nil {
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
package http
|
|
@ -60,8 +60,8 @@ type ConsulBackend struct {
|
|||
kv *api.KV
|
||||
permitPool *PermitPool
|
||||
serviceLock sync.RWMutex
|
||||
advertiseHost string
|
||||
advertisePort int64
|
||||
redirectHost string
|
||||
redirectPort int64
|
||||
serviceName string
|
||||
serviceTags []string
|
||||
disableRegistration bool
|
||||
|
@ -111,9 +111,9 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro
|
|||
logger.Printf("[DEBUG]: physical/consul: config service set to %s", service)
|
||||
|
||||
// Get the additional tags to attach to the registered service name
|
||||
tags := conf["service-tags"]
|
||||
tags := conf["service_tags"]
|
||||
|
||||
logger.Printf("[DEBUG]: physical/consul: config service-tags set to %s", tags)
|
||||
logger.Printf("[DEBUG]: physical/consul: config service_tags set to %s", tags)
|
||||
|
||||
checkTimeout := defaultCheckTimeout
|
||||
checkTimeoutStr, ok := conf["check_timeout"]
|
||||
|
@ -416,20 +416,20 @@ func (c *ConsulBackend) checkDuration() time.Duration {
|
|||
return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
|
||||
}
|
||||
|
||||
func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) {
|
||||
if err := c.setAdvertiseAddr(advertiseAddr); err != nil {
|
||||
func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) {
|
||||
if err := c.setRedirectAddr(redirectAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 'server' command will wait for the below goroutine to complete
|
||||
waitGroup.Add(1)
|
||||
|
||||
go c.runEventDemuxer(waitGroup, shutdownCh, advertiseAddr, activeFunc, sealedFunc)
|
||||
go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) {
|
||||
func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) {
|
||||
// This defer statement should be executed last. So push it first.
|
||||
defer waitGroup.Done()
|
||||
|
||||
|
@ -532,7 +532,7 @@ func (c *ConsulBackend) checkID() string {
|
|||
// serviceID returns the Vault ServiceID for use in Consul. Assume at least
|
||||
// a read lock is held.
|
||||
func (c *ConsulBackend) serviceID() string {
|
||||
return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort)
|
||||
return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort)
|
||||
}
|
||||
|
||||
// reconcileConsul queries the state of Vault Core and Consul and fixes up
|
||||
|
@ -585,8 +585,8 @@ func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc a
|
|||
ID: serviceID,
|
||||
Name: c.serviceName,
|
||||
Tags: tags,
|
||||
Port: int(c.advertisePort),
|
||||
Address: c.advertiseHost,
|
||||
Port: int(c.redirectPort),
|
||||
Address: c.redirectHost,
|
||||
EnableTagOverride: false,
|
||||
}
|
||||
|
||||
|
@ -637,18 +637,18 @@ func (c *ConsulBackend) fetchServiceTags(active bool) []string {
|
|||
return append(c.serviceTags, activeTag)
|
||||
}
|
||||
|
||||
func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) {
|
||||
func (c *ConsulBackend) setRedirectAddr(addr string) (err error) {
|
||||
if addr == "" {
|
||||
return fmt.Errorf("advertise address must not be empty")
|
||||
return fmt.Errorf("redirect address must not be empty")
|
||||
}
|
||||
|
||||
url, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err)
|
||||
return errwrap.Wrapf(fmt.Sprintf(`failed to parse redirect URL "%v": {{err}}`, addr), err)
|
||||
}
|
||||
|
||||
var portStr string
|
||||
c.advertiseHost, portStr, err = net.SplitHostPort(url.Host)
|
||||
c.redirectHost, portStr, err = net.SplitHostPort(url.Host)
|
||||
if err != nil {
|
||||
if url.Scheme == "http" {
|
||||
portStr = "80"
|
||||
|
@ -656,13 +656,13 @@ func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) {
|
|||
portStr = "443"
|
||||
} else if url.Scheme == "unix" {
|
||||
portStr = "-1"
|
||||
c.advertiseHost = url.Path
|
||||
c.redirectHost = url.Path
|
||||
} else {
|
||||
return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err)
|
||||
return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in redirect address "%v": {{err}}`, url.Host), err)
|
||||
}
|
||||
}
|
||||
c.advertisePort, err = strconv.ParseInt(portStr, 10, 0)
|
||||
if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 {
|
||||
c.redirectPort, err = strconv.ParseInt(portStr, 10, 0)
|
||||
if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 {
|
||||
return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
|
||||
}
|
||||
|
||||
|
|
|
@ -78,8 +78,8 @@ func TestConsul_ServiceTags(t *testing.T) {
|
|||
consulConfig := map[string]string{
|
||||
"path": "seaTech/",
|
||||
"service": "astronomy",
|
||||
"service-tags": "deadbeef, cafeefac, deadc0de, feedface",
|
||||
"advertiseAddr": "http://127.0.0.2:8200",
|
||||
"service_tags": "deadbeef, cafeefac, deadc0de, feedface",
|
||||
"redirect_addr": "http://127.0.0.2:8200",
|
||||
"check_timeout": "6s",
|
||||
"address": "127.0.0.2",
|
||||
"scheme": "https",
|
||||
|
@ -112,38 +112,38 @@ func TestConsul_ServiceTags(t *testing.T) {
|
|||
|
||||
func TestConsul_newConsulBackend(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
consulConfig map[string]string
|
||||
fail bool
|
||||
advertiseAddr string
|
||||
checkTimeout time.Duration
|
||||
path string
|
||||
service string
|
||||
address string
|
||||
scheme string
|
||||
token string
|
||||
max_parallel int
|
||||
disableReg bool
|
||||
name string
|
||||
consulConfig map[string]string
|
||||
fail bool
|
||||
redirectAddr string
|
||||
checkTimeout time.Duration
|
||||
path string
|
||||
service string
|
||||
address string
|
||||
scheme string
|
||||
token string
|
||||
max_parallel int
|
||||
disableReg bool
|
||||
}{
|
||||
{
|
||||
name: "Valid default config",
|
||||
consulConfig: map[string]string{},
|
||||
checkTimeout: 5 * time.Second,
|
||||
advertiseAddr: "http://127.0.0.1:8200",
|
||||
path: "vault/",
|
||||
service: "vault",
|
||||
address: "127.0.0.1:8500",
|
||||
scheme: "http",
|
||||
token: "",
|
||||
max_parallel: 4,
|
||||
disableReg: false,
|
||||
name: "Valid default config",
|
||||
consulConfig: map[string]string{},
|
||||
checkTimeout: 5 * time.Second,
|
||||
redirectAddr: "http://127.0.0.1:8200",
|
||||
path: "vault/",
|
||||
service: "vault",
|
||||
address: "127.0.0.1:8500",
|
||||
scheme: "http",
|
||||
token: "",
|
||||
max_parallel: 4,
|
||||
disableReg: false,
|
||||
},
|
||||
{
|
||||
name: "Valid modified config",
|
||||
consulConfig: map[string]string{
|
||||
"path": "seaTech/",
|
||||
"service": "astronomy",
|
||||
"advertiseAddr": "http://127.0.0.2:8200",
|
||||
"redirect_addr": "http://127.0.0.2:8200",
|
||||
"check_timeout": "6s",
|
||||
"address": "127.0.0.2",
|
||||
"scheme": "https",
|
||||
|
@ -151,14 +151,14 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
"max_parallel": "4",
|
||||
"disable_registration": "false",
|
||||
},
|
||||
checkTimeout: 6 * time.Second,
|
||||
path: "seaTech/",
|
||||
service: "astronomy",
|
||||
advertiseAddr: "http://127.0.0.2:8200",
|
||||
address: "127.0.0.2",
|
||||
scheme: "https",
|
||||
token: "deadbeef-cafeefac-deadc0de-feedface",
|
||||
max_parallel: 4,
|
||||
checkTimeout: 6 * time.Second,
|
||||
path: "seaTech/",
|
||||
service: "astronomy",
|
||||
redirectAddr: "http://127.0.0.2:8200",
|
||||
address: "127.0.0.2",
|
||||
scheme: "https",
|
||||
token: "deadbeef-cafeefac-deadc0de-feedface",
|
||||
max_parallel: 4,
|
||||
},
|
||||
{
|
||||
name: "check timeout too short",
|
||||
|
@ -197,7 +197,7 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
|
||||
var shutdownCh ShutdownChannel
|
||||
waitGroup := &sync.WaitGroup{}
|
||||
if err := c.RunServiceDiscovery(waitGroup, shutdownCh, test.advertiseAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil {
|
||||
if err := c.RunServiceDiscovery(waitGroup, shutdownCh, test.redirectAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
|
@ -245,7 +245,7 @@ func TestConsul_serviceTags(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConsul_setAdvertiseAddr(t *testing.T) {
|
||||
func TestConsul_setRedirectAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
host string
|
||||
|
@ -287,7 +287,7 @@ func TestConsul_setAdvertiseAddr(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
c := testConsulBackend(t)
|
||||
err := c.setAdvertiseAddr(test.addr)
|
||||
err := c.setRedirectAddr(test.addr)
|
||||
if test.pass {
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
|
@ -300,12 +300,12 @@ func TestConsul_setAdvertiseAddr(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
if c.advertiseHost != test.host {
|
||||
t.Fatalf("bad: %v != %v", c.advertiseHost, test.host)
|
||||
if c.redirectHost != test.host {
|
||||
t.Fatalf("bad: %v != %v", c.redirectHost, test.host)
|
||||
}
|
||||
|
||||
if c.advertisePort != test.port {
|
||||
t.Fatalf("bad: %v != %v", c.advertisePort, test.port)
|
||||
if c.redirectPort != test.port {
|
||||
t.Fatalf("bad: %v != %v", c.redirectPort, test.port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -338,28 +338,28 @@ func TestConsul_NotifySealedStateChange(t *testing.T) {
|
|||
|
||||
func TestConsul_serviceID(t *testing.T) {
|
||||
passingTests := []struct {
|
||||
name string
|
||||
advertiseAddr string
|
||||
serviceName string
|
||||
expected string
|
||||
name string
|
||||
redirectAddr string
|
||||
serviceName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid host w/o slash",
|
||||
advertiseAddr: "http://127.0.0.1:8200",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
name: "valid host w/o slash",
|
||||
redirectAddr: "http://127.0.0.1:8200",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
},
|
||||
{
|
||||
name: "valid host w/ slash",
|
||||
advertiseAddr: "http://127.0.0.1:8200/",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
name: "valid host w/ slash",
|
||||
redirectAddr: "http://127.0.0.1:8200/",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
},
|
||||
{
|
||||
name: "valid https host w/ slash",
|
||||
advertiseAddr: "https://127.0.0.1:8200/",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
name: "valid https host w/ slash",
|
||||
redirectAddr: "https://127.0.0.1:8200/",
|
||||
serviceName: "sea-tech-astronomy",
|
||||
expected: "sea-tech-astronomy:127.0.0.1:8200",
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -368,7 +368,7 @@ func TestConsul_serviceID(t *testing.T) {
|
|||
"service": test.serviceName,
|
||||
})
|
||||
|
||||
if err := c.setAdvertiseAddr(test.advertiseAddr); err != nil {
|
||||
if err := c.setRedirectAddr(test.redirectAddr); err != nil {
|
||||
t.Fatalf("bad: %s %v", test.name, err)
|
||||
}
|
||||
|
||||
|
@ -445,9 +445,9 @@ func TestConsulHABackend(t *testing.T) {
|
|||
}
|
||||
testHABackend(t, ha, ha)
|
||||
|
||||
detect, ok := b.(AdvertiseDetect)
|
||||
detect, ok := b.(RedirectDetect)
|
||||
if !ok {
|
||||
t.Fatalf("consul does not implement AdvertiseDetect")
|
||||
t.Fatalf("consul does not implement RedirectDetect")
|
||||
}
|
||||
host, err := detect.DetectHostAddr()
|
||||
if err != nil {
|
||||
|
|
|
@ -44,10 +44,10 @@ type HABackend interface {
|
|||
HAEnabled() bool
|
||||
}
|
||||
|
||||
// AdvertiseDetect is an optional interface that an HABackend
|
||||
// can implement. If they do, an advertise address can be automatically
|
||||
// RedirectDetect is an optional interface that an HABackend
|
||||
// can implement. If they do, a redirect address can be automatically
|
||||
// detected.
|
||||
type AdvertiseDetect interface {
|
||||
type RedirectDetect interface {
|
||||
// DetectHostAddr is used to detect the host address
|
||||
DetectHostAddr() (string, error)
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ type ServiceDiscovery interface {
|
|||
|
||||
// Run executes any background service discovery tasks until the
|
||||
// shutdown channel is closed.
|
||||
RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error
|
||||
RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error
|
||||
}
|
||||
|
||||
type Lock interface {
|
||||
|
@ -114,6 +114,9 @@ var builtinBackends = map[string]Factory{
|
|||
"inmem": func(_ map[string]string, logger *log.Logger) (Backend, error) {
|
||||
return NewInmem(logger), nil
|
||||
},
|
||||
"inmem_ha": func(_ map[string]string, logger *log.Logger) (Backend, error) {
|
||||
return NewInmemHA(logger), nil
|
||||
},
|
||||
"consul": newConsulBackend,
|
||||
"zookeeper": newZookeeperBackend,
|
||||
"file": newFileBackend,
|
||||
|
|
|
@ -85,7 +85,7 @@ func TestCore_EnableAudit(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func TestCore_DisableAudit(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ func TestCore_DefaultAuditTable(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ func TestCore_DefaultAuthTable(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func TestCore_EnableCredential(t *testing.T) {
|
|||
c2.credentialBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) {
|
||||
return &NoopBackend{}, nil
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -221,7 +221,7 @@ func TestCore_DisableCredential(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
389
vault/cluster.go
389
vault/cluster.go
|
@ -1,18 +1,55 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/requestutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// Storage path where the local cluster name and identifier are stored
|
||||
coreLocalClusterInfoPath = "core/cluster/local/info"
|
||||
|
||||
corePrivateKeyTypeP521 = "p521"
|
||||
|
||||
// Internal so as not to log a trace message
|
||||
IntNoForwardingHeaderName = "X-Vault-Internal-No-Request-Forwarding"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCannotForward = errors.New("cannot forward request; no connection or address not known")
|
||||
)
|
||||
|
||||
type clusterKeyParams struct {
|
||||
Type string `json:"type"`
|
||||
X *big.Int `json:"x"`
|
||||
Y *big.Int `json:"y"`
|
||||
D *big.Int `json:"d"`
|
||||
}
|
||||
|
||||
type activeConnection struct {
|
||||
*http.Client
|
||||
clusterAddr string
|
||||
}
|
||||
|
||||
// Structure representing the storage entry that holds cluster information
|
||||
type Cluster struct {
|
||||
// Name of the cluster
|
||||
|
@ -49,10 +86,50 @@ func (c *Core) Cluster() (*Cluster, error) {
|
|||
return &cluster, nil
|
||||
}
|
||||
|
||||
// This is idempotent, so we return nil if there is no entry yet (say, because
|
||||
// the active node has not yet generated this)
|
||||
func (c *Core) loadClusterTLS(adv activeAdvertisement) error {
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
switch {
|
||||
case adv.ClusterKeyParams.X == nil, adv.ClusterKeyParams.Y == nil, adv.ClusterKeyParams.D == nil:
|
||||
c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed to parse local cluster key due to missing params")
|
||||
return fmt.Errorf("failed to parse local cluster key")
|
||||
case adv.ClusterKeyParams.Type == corePrivateKeyTypeP521:
|
||||
default:
|
||||
c.logger.Printf("[ERR] core/loadClusterPrivateKey: unknown local cluster key type %v", adv.ClusterKeyParams.Type)
|
||||
return fmt.Errorf("failed to find valid local cluster key type")
|
||||
}
|
||||
c.localClusterPrivateKey = &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P521(),
|
||||
X: adv.ClusterKeyParams.X,
|
||||
Y: adv.ClusterKeyParams.Y,
|
||||
},
|
||||
D: adv.ClusterKeyParams.D,
|
||||
}
|
||||
|
||||
c.localClusterCert = adv.ClusterCert
|
||||
|
||||
cert, err := x509.ParseCertificate(c.localClusterCert)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core/loadClusterPrivateKey: failed parsing local cluster certificate: %v", err)
|
||||
return fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
c.localClusterCertPool.AddCert(cert)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupCluster creates storage entries for holding Vault cluster information.
|
||||
// Entries will be created only if they are not already present. If clusterName
|
||||
// is not supplied, this method will auto-generate it.
|
||||
func (c *Core) setupCluster() error {
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
// Check if storage index is already present or not
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
|
@ -60,12 +137,12 @@ func (c *Core) setupCluster() error {
|
|||
return err
|
||||
}
|
||||
|
||||
var modified bool
|
||||
|
||||
if cluster == nil {
|
||||
cluster = &Cluster{}
|
||||
}
|
||||
|
||||
var modified bool
|
||||
|
||||
if cluster.Name == "" {
|
||||
// If cluster name is not supplied, generate one
|
||||
if c.clusterName == "" {
|
||||
|
@ -75,6 +152,7 @@ func (c *Core) setupCluster() error {
|
|||
c.logger.Printf("[ERR] core: failed to generate cluster name: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.clusterName = fmt.Sprintf("vault-cluster-%08x", clusterNameBytes)
|
||||
}
|
||||
|
||||
|
@ -84,17 +162,71 @@ func (c *Core) setupCluster() error {
|
|||
}
|
||||
|
||||
if cluster.ID == "" {
|
||||
c.logger.Printf("[TRACE] core: cluster ID not found, generating new")
|
||||
// Generate a clusterID
|
||||
clusterID, err := uuid.GenerateUUID()
|
||||
cluster.ID, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to generate cluster identifier: %v", err)
|
||||
return err
|
||||
}
|
||||
cluster.ID = clusterID
|
||||
c.logger.Printf("[DEBUG] core: cluster ID set to %s", cluster.ID)
|
||||
modified = true
|
||||
}
|
||||
|
||||
// Create a private key
|
||||
{
|
||||
c.logger.Printf("[TRACE] core: generating cluster private key")
|
||||
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: failed to generate local cluster key: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.localClusterPrivateKey = key
|
||||
}
|
||||
|
||||
// Create a certificate
|
||||
{
|
||||
c.logger.Printf("[TRACE] core: generating local cluster certificate")
|
||||
|
||||
host, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
},
|
||||
DNSNames: []string{host},
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
// 30 years of single-active uptime ought to be enough for anybody
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, c.localClusterPrivateKey.Public(), c.localClusterPrivateKey)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: error generating self-signed cert: %v", err)
|
||||
return fmt.Errorf("unable to generate local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
_, err = x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core: error parsing self-signed cert: %v", err)
|
||||
return fmt.Errorf("error parsing generated certificate: %v", err)
|
||||
}
|
||||
|
||||
c.localClusterCert = certBytes
|
||||
}
|
||||
|
||||
if modified {
|
||||
// Encode the cluster information into as a JSON string
|
||||
rawCluster, err := json.Marshal(cluster)
|
||||
|
@ -116,3 +248,252 @@ func (c *Core) setupCluster() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetClusterListenerSetupFunc sets the listener setup func, which is used to
|
||||
// know which ports to listen on and a handler to use.
|
||||
func (c *Core) SetClusterListenerSetupFunc(setupFunc func() ([]net.Listener, http.Handler, error)) {
|
||||
c.clusterListenerSetupFunc = setupFunc
|
||||
}
|
||||
|
||||
// startClusterListener starts cluster request listeners during postunseal. It
|
||||
// is assumed that the state lock is held while this is run.
|
||||
func (c *Core) startClusterListener() error {
|
||||
if c.clusterListenerShutdownCh != nil {
|
||||
c.logger.Printf("[ERR] core/startClusterListener: attempt to set up cluster listeners when already set up")
|
||||
return fmt.Errorf("cluster listeners already setup")
|
||||
}
|
||||
|
||||
if c.clusterListenerSetupFunc == nil {
|
||||
c.logger.Printf("[ERR] core/startClusterListener: cluster listener setup function has not been set")
|
||||
return fmt.Errorf("cluster listener setup function has not been set")
|
||||
}
|
||||
|
||||
if c.clusterAddr == "" {
|
||||
c.logger.Printf("[TRACE] core/startClusterListener: clustering disabled, starting listeners")
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Printf("[TRACE] core/startClusterListener: starting listeners")
|
||||
|
||||
lns, handler, err := c.clusterListenerSetupFunc()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tlsConfig, err := c.ClusterTLSConfig()
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core/startClusterListener: failed to get tls configuration: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
tlsLns := make([]net.Listener, 0, len(lns))
|
||||
for _, ln := range lns {
|
||||
tlsLn := tls.NewListener(ln, tlsConfig)
|
||||
tlsLns = append(tlsLns, tlsLn)
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
http2.ConfigureServer(server, nil)
|
||||
c.logger.Printf("[TRACE] core/startClusterListener: serving cluster requests on %s", tlsLn.Addr())
|
||||
go server.Serve(tlsLn)
|
||||
}
|
||||
|
||||
c.clusterListenerShutdownCh = make(chan struct{})
|
||||
c.clusterListenerShutdownSuccessCh = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
<-c.clusterListenerShutdownCh
|
||||
c.logger.Printf("[TRACE] core/startClusterListener: shutting down listeners")
|
||||
for _, tlsLn := range tlsLns {
|
||||
tlsLn.Close()
|
||||
}
|
||||
close(c.clusterListenerShutdownSuccessCh)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClusterListener stops any existing listeners during preseal. It is
|
||||
// assumed that the state lock is held while this is run.
|
||||
func (c *Core) stopClusterListener() {
|
||||
c.logger.Printf("[TRACE] core/stopClusterListener: stopping listeners")
|
||||
if c.clusterListenerShutdownCh != nil {
|
||||
close(c.clusterListenerShutdownCh)
|
||||
defer func() { c.clusterListenerShutdownCh = nil }()
|
||||
}
|
||||
|
||||
// The reason for this loop-de-loop is that we may be unsealing again
|
||||
// quickly, and if the listeners are not yet closed, we will get socket
|
||||
// bind errors. This ensures proper ordering.
|
||||
if c.clusterListenerShutdownSuccessCh == nil {
|
||||
return
|
||||
}
|
||||
<-c.clusterListenerShutdownSuccessCh
|
||||
defer func() { c.clusterListenerShutdownSuccessCh = nil }()
|
||||
}
|
||||
|
||||
// ClusterTLSConfig generates a TLS configuration based on the local cluster
|
||||
// key and cert. This isn't called often and we lock because the CertPool is
|
||||
// not concurrency-safe.
|
||||
func (c *Core) ClusterTLSConfig() (*tls.Config, error) {
|
||||
c.clusterParamsLock.Lock()
|
||||
defer c.clusterParamsLock.Unlock()
|
||||
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cluster == nil {
|
||||
return nil, fmt.Errorf("cluster information is nil")
|
||||
}
|
||||
if c.localClusterCert == nil || len(c.localClusterCert) == 0 {
|
||||
return nil, fmt.Errorf("cluster certificate is nil")
|
||||
}
|
||||
|
||||
parsedCert, err := x509.ParseCertificate(c.localClusterCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing local cluster certificate: %v", err)
|
||||
}
|
||||
|
||||
// This is idempotent, so be sure it's been added
|
||||
c.localClusterCertPool.AddCert(parsedCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
tls.Certificate{
|
||||
Certificate: [][]byte{c.localClusterCert},
|
||||
PrivateKey: c.localClusterPrivateKey,
|
||||
},
|
||||
},
|
||||
RootCAs: c.localClusterCertPool,
|
||||
ServerName: parsedCert.Subject.CommonName,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: c.localClusterCertPool,
|
||||
NextProtos: []string{
|
||||
"h2",
|
||||
},
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// refreshRequestForwardingConnection ensures that the client/transport are
|
||||
// alive and that the current active address value matches the most
|
||||
// recently-known address.
|
||||
func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
|
||||
c.requestForwardingConnectionLock.Lock()
|
||||
defer c.requestForwardingConnectionLock.Unlock()
|
||||
|
||||
// It's nil but we don't have an address anyways, so exit
|
||||
if c.requestForwardingConnection == nil && clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE: We don't fast path the case where we have a connection because the
|
||||
// address is the same, because the cert/key could have changed if the
|
||||
// active node ended up being the same node. Before we hit this function in
|
||||
// Leader() we'll have done a hash on the advertised info to ensure that we
|
||||
// won't hit this function unnecessarily anyways.
|
||||
|
||||
// Disabled, potentially
|
||||
if clusterAddr == "" {
|
||||
c.requestForwardingConnection = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConfig, err := c.ClusterTLSConfig()
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core/refreshRequestForwardingConnection: error fetching cluster tls configuration: %v", err)
|
||||
return err
|
||||
}
|
||||
tp := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
err = http2.ConfigureTransport(tp)
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core/refreshRequestForwardingConnection: error configuring transport: %v", err)
|
||||
return err
|
||||
}
|
||||
c.requestForwardingConnection = &activeConnection{
|
||||
Client: &http.Client{
|
||||
Transport: tp,
|
||||
},
|
||||
clusterAddr: clusterAddr,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardRequest forwards a given request to the active node and returns the
|
||||
// response.
|
||||
func (c *Core) ForwardRequest(req *http.Request) (*http.Response, error) {
|
||||
c.requestForwardingConnectionLock.RLock()
|
||||
defer c.requestForwardingConnectionLock.RUnlock()
|
||||
if c.requestForwardingConnection == nil {
|
||||
return nil, ErrCannotForward
|
||||
}
|
||||
|
||||
if c.requestForwardingConnection.clusterAddr == "" {
|
||||
return nil, ErrCannotForward
|
||||
}
|
||||
|
||||
freq, err := requestutil.GenerateForwardedRequest(req, c.requestForwardingConnection.clusterAddr+"/cluster/forwarded-request")
|
||||
if err != nil {
|
||||
c.logger.Printf("[ERR] core/ForwardRequest: error creating forwarded request: %v", err)
|
||||
return nil, fmt.Errorf("error creating forwarding request")
|
||||
}
|
||||
|
||||
return c.requestForwardingConnection.Do(freq)
|
||||
}
|
||||
|
||||
// WrapListenersForClustering takes in Vault's listeners and original HTTP
|
||||
// handler, creates a new handler that handles forwarded requests, and returns
|
||||
// the cluster setup function that creates the new listners and assigns to the
|
||||
// new handler
|
||||
func WrapListenersForClustering(addrs []string, handler http.Handler, logger *log.Logger) func() ([]net.Listener, http.Handler, error) {
|
||||
// This mux handles cluster functions (right now, only forwarded requests)
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/cluster/forwarded-request", func(w http.ResponseWriter, req *http.Request) {
|
||||
freq, err := requestutil.ParseForwardedRequest(req)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Printf("[ERR] http/ForwardedRequestHandler: error parsing forwarded request: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
type errorResponse struct {
|
||||
Errors []string
|
||||
}
|
||||
resp := &errorResponse{
|
||||
Errors: []string{
|
||||
err.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
// To avoid the risk of a forward loop in some pathological condition,
|
||||
// set the no-forward header
|
||||
freq.Header.Set(IntNoForwardingHeaderName, "true")
|
||||
handler.ServeHTTP(w, freq)
|
||||
})
|
||||
|
||||
return func() ([]net.Listener, http.Handler, error) {
|
||||
ret := make([]net.Listener, 0, len(addrs))
|
||||
// Loop over the existing listeners and start listeners on appropriate ports
|
||||
for _, addr := range addrs {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ret = append(ret, ln)
|
||||
}
|
||||
|
||||
return ret, mux, nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,353 @@
|
|||
package vault
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
func TestCluster(t *testing.T) {
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
)
|
||||
|
||||
func TestClusterFetching(t *testing.T) {
|
||||
c, _, _ := TestCoreUnsealed(t)
|
||||
|
||||
err := c.setupCluster()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Test whether expected values are found
|
||||
if cluster == nil || cluster.Name == "" || cluster.ID == "" {
|
||||
t.Fatalf("cluster information missing: cluster: %#v", cluster)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClusterHAFetching(t *testing.T) {
|
||||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
redirect := "http://127.0.0.1:8200"
|
||||
|
||||
c, err := NewCore(&CoreConfig{
|
||||
Physical: physical.NewInmemHA(logger),
|
||||
HAPhysical: physical.NewInmemHA(logger),
|
||||
RedirectAddr: redirect,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, _ := TestCoreInit(t, c)
|
||||
if _, err := TestCoreUnseal(c, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
// Verify unsealed
|
||||
sealed, err := c.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
|
||||
// Wait for core to become active
|
||||
TestWaitActive(t, c)
|
||||
|
||||
cluster, err := c.Cluster()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Test whether expected values are found
|
||||
if cluster == nil || cluster.Name == "" || cluster.ID == "" {
|
||||
t.Fatalf("cluster information missing: cluster:%#v", cluster)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCluster_ListenForRequests(t *testing.T) {
|
||||
// Make this nicer for tests
|
||||
manualStepDownSleepPeriod = 5 * time.Second
|
||||
|
||||
cores := TestCluster(t, []http.Handler{nil, nil, nil}, nil, false)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
// Wait for core to become active
|
||||
TestWaitActive(t, cores[0].Core)
|
||||
|
||||
checkListenersFunc := func(expectFail bool) {
|
||||
tlsConfig, err := cores[0].ClusterTLSConfig()
|
||||
if err != nil && err.Error() != ErrSealed.Error() {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, ln := range cores[0].Listeners {
|
||||
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatal("%s not a TCP port", tcpAddr.String())
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port+1), tlsConfig)
|
||||
if err != nil {
|
||||
if expectFail {
|
||||
t.Logf("testing %s:%d unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
|
||||
continue
|
||||
}
|
||||
t.Fatalf("error: %v\nlisteners are\n%#v\n%#v\n", err, cores[0].Listeners[0], cores[0].Listeners[1])
|
||||
}
|
||||
if expectFail {
|
||||
t.Fatalf("testing %s:%d not unsuccessful as expected", tcpAddr.IP.String(), tcpAddr.Port+1)
|
||||
}
|
||||
err = conn.Handshake()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connState := conn.ConnectionState()
|
||||
switch {
|
||||
case connState.Version != tls.VersionTLS12:
|
||||
t.Fatal("version mismatch")
|
||||
case connState.NegotiatedProtocol != "h2" || !connState.NegotiatedProtocolIsMutual:
|
||||
t.Fatal("bad protocol negotiation")
|
||||
}
|
||||
t.Logf("testing %s:%d successful", tcpAddr.IP.String(), tcpAddr.Port+1)
|
||||
}
|
||||
}
|
||||
|
||||
checkListenersFunc(false)
|
||||
|
||||
err := cores[0].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// StepDown doesn't wait during actual preSeal so give time for listeners
|
||||
// to close
|
||||
time.Sleep(1 * time.Second)
|
||||
checkListenersFunc(true)
|
||||
|
||||
// After this period it should be active again
|
||||
time.Sleep(manualStepDownSleepPeriod)
|
||||
checkListenersFunc(false)
|
||||
|
||||
err = cores[0].Seal(root)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
// After sealing it should be inactive again
|
||||
checkListenersFunc(true)
|
||||
}
|
||||
|
||||
func TestCluster_ForwardRequests(t *testing.T) {
|
||||
// Make this nicer for tests
|
||||
manualStepDownSleepPeriod = 5 * time.Second
|
||||
|
||||
handler1 := http.NewServeMux()
|
||||
handler1.HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(201)
|
||||
w.Write([]byte("core1"))
|
||||
})
|
||||
handler2 := http.NewServeMux()
|
||||
handler2.HandleFunc("/core2", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(202)
|
||||
w.Write([]byte("core2"))
|
||||
})
|
||||
handler3 := http.NewServeMux()
|
||||
handler3.HandleFunc("/core3", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(203)
|
||||
w.Write([]byte("core3"))
|
||||
})
|
||||
|
||||
cores := TestCluster(t, []http.Handler{handler1, handler2, handler3}, nil, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
// Wait for core to become active
|
||||
TestWaitActive(t, cores[0].Core)
|
||||
|
||||
// Test forwarding a request. Since we're going directly from core to core
|
||||
// with no fallback we know that if it worked, request handling is working
|
||||
testCluster_ForwardRequests(t, cores[1], "core1")
|
||||
testCluster_ForwardRequests(t, cores[2], "core1")
|
||||
|
||||
//
|
||||
// Now we do a bunch of round-robining. The point is to make sure that as
|
||||
// nodes come and go, we can always successfully forward to the active
|
||||
// node.
|
||||
//
|
||||
|
||||
// Ensure active core is cores[1] and test
|
||||
err := cores[0].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
_ = cores[2].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
TestWaitActive(t, cores[1].Core)
|
||||
testCluster_ForwardRequests(t, cores[0], "core2")
|
||||
testCluster_ForwardRequests(t, cores[2], "core2")
|
||||
|
||||
// Ensure active core is cores[2] and test
|
||||
err = cores[1].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
_ = cores[0].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
TestWaitActive(t, cores[2].Core)
|
||||
testCluster_ForwardRequests(t, cores[0], "core3")
|
||||
testCluster_ForwardRequests(t, cores[1], "core3")
|
||||
|
||||
// Ensure active core is cores[0] and test
|
||||
err = cores[2].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
_ = cores[1].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
TestWaitActive(t, cores[0].Core)
|
||||
testCluster_ForwardRequests(t, cores[1], "core1")
|
||||
testCluster_ForwardRequests(t, cores[2], "core1")
|
||||
|
||||
// Ensure active core is cores[1] and test
|
||||
err = cores[0].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
_ = cores[2].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
TestWaitActive(t, cores[1].Core)
|
||||
testCluster_ForwardRequests(t, cores[0], "core2")
|
||||
testCluster_ForwardRequests(t, cores[2], "core2")
|
||||
|
||||
// Ensure active core is cores[2] and test
|
||||
err = cores[1].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
_ = cores[0].StepDown(&logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "sys/step-down",
|
||||
ClientToken: root,
|
||||
})
|
||||
time.Sleep(2 * time.Second)
|
||||
TestWaitActive(t, cores[2].Core)
|
||||
testCluster_ForwardRequests(t, cores[0], "core3")
|
||||
testCluster_ForwardRequests(t, cores[1], "core3")
|
||||
}
|
||||
|
||||
func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, remoteCoreID string) {
|
||||
standby, err := c.Standby()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !standby {
|
||||
t.Fatal("expected core to be standby")
|
||||
}
|
||||
|
||||
// We need to call Leader as that refreshes the connection info
|
||||
isLeader, _, err := c.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("core should not be leader")
|
||||
}
|
||||
|
||||
bodBuf := bytes.NewReader([]byte(`{ "foo": "bar", "zip": "zap" }`))
|
||||
req, err := http.NewRequest("PUT", "https://pushit.real.good:9281/"+remoteCoreID, bodBuf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add("X-Vault-Token", c.Root)
|
||||
|
||||
resp, err := c.ForwardRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil resp")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body := bytes.NewBuffer(nil)
|
||||
body.ReadFrom(resp.Body)
|
||||
|
||||
if body.String() != remoteCoreID {
|
||||
t.Fatalf("expected %s, got %s", remoteCoreID, body.String())
|
||||
}
|
||||
switch body.String() {
|
||||
case "core1":
|
||||
if resp.StatusCode != 201 {
|
||||
t.Fatal("bad response")
|
||||
}
|
||||
case "core2":
|
||||
if resp.StatusCode != 202 {
|
||||
t.Fatal("bad response")
|
||||
}
|
||||
case "core3":
|
||||
if resp.StatusCode != 203 {
|
||||
t.Fatal("bad response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
227
vault/core.go
227
vault/core.go
|
@ -2,9 +2,15 @@ package vault
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
|
@ -16,6 +22,7 @@ import (
|
|||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/mlock"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
|
@ -46,10 +53,6 @@ const (
|
|||
// leaderPrefixCleanDelay is how long to wait between deletions
|
||||
// of orphaned leader keys, to prevent slamming the backend.
|
||||
leaderPrefixCleanDelay = 200 * time.Millisecond
|
||||
|
||||
// manualStepDownSleepPeriod is how long to sleep after a user-initiated
|
||||
// step down of the active node, to prevent instantly regrabbing the lock
|
||||
manualStepDownSleepPeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -76,6 +79,11 @@ var (
|
|||
// ErrHANotEnabled is returned if the operation only makes sense
|
||||
// in an HA setting
|
||||
ErrHANotEnabled = errors.New("Vault is not configured for highly-available mode")
|
||||
|
||||
// manualStepDownSleepPeriod is how long to sleep after a user-initiated
|
||||
// step down of the active node, to prevent instantly regrabbing the lock.
|
||||
// It's var not const so that tests can manipulate it.
|
||||
manualStepDownSleepPeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
// NonFatalError is an error that can be returned during NewCore that should be
|
||||
|
@ -102,6 +110,13 @@ func (e *ErrInvalidKey) Error() string {
|
|||
return fmt.Sprintf("invalid key: %v", e.Reason)
|
||||
}
|
||||
|
||||
type activeAdvertisement struct {
|
||||
RedirectAddr string `json:"redirect_addr"`
|
||||
ClusterAddr string `json:"cluster_addr"`
|
||||
ClusterCert []byte `json:"cluster_cert"`
|
||||
ClusterKeyParams clusterKeyParams `json:"cluster_key_params"`
|
||||
}
|
||||
|
||||
// Core is used as the central manager of Vault activity. It is the primary point of
|
||||
// interface for API handlers and is responsible for managing the logical and physical
|
||||
// backends, router, security barrier, and audit trails.
|
||||
|
@ -109,8 +124,11 @@ type Core struct {
|
|||
// HABackend may be available depending on the physical backend
|
||||
ha physical.HABackend
|
||||
|
||||
// AdvertiseAddr is the address we advertise as leader if held
|
||||
advertiseAddr string
|
||||
// redirectAddr is the address we advertise as leader if held
|
||||
redirectAddr string
|
||||
|
||||
// clusterAddr is the address we use for clustering
|
||||
clusterAddr string
|
||||
|
||||
// physical backend is the un-trusted backend with durable data
|
||||
physical physical.Backend
|
||||
|
@ -220,7 +238,39 @@ type Core struct {
|
|||
// cachingDisabled indicates whether caches are disabled
|
||||
cachingDisabled bool
|
||||
|
||||
//
|
||||
// Cluster information
|
||||
//
|
||||
// Name
|
||||
clusterName string
|
||||
// Used to modify cluster TLS params
|
||||
clusterParamsLock sync.RWMutex
|
||||
// The private key stored in the barrier used for establishing
|
||||
// mutually-authenticated connections between Vault cluster members
|
||||
localClusterPrivateKey crypto.Signer
|
||||
// The local cluster cert
|
||||
localClusterCert []byte
|
||||
// The cert pool containing the self-signed CA as a trusted CA
|
||||
localClusterCertPool *x509.CertPool
|
||||
// The setup function that gives us the listeners for the cluster-cluster
|
||||
// connection and the handler to use
|
||||
clusterListenerSetupFunc func() ([]net.Listener, http.Handler, error)
|
||||
// Shutdown channel for the cluster listeners
|
||||
clusterListenerShutdownCh chan struct{}
|
||||
// Shutdown success channel. We need this to be done serially to ensure
|
||||
// that binds are removed before they might be reinstated.
|
||||
clusterListenerShutdownSuccessCh chan struct{}
|
||||
// Connection info containing a client and a current active address
|
||||
requestForwardingConnection *activeConnection
|
||||
// Write lock used to ensure that we don't have multiple connections adjust
|
||||
// this value at the same time
|
||||
requestForwardingConnectionLock sync.RWMutex
|
||||
// Most recent hashed value of the advertise/cluster info. Used to avoid
|
||||
// repeatedly JSON parsing the same values.
|
||||
clusterActiveAdvertisementHash []byte
|
||||
// Cache of most recently known active advertisement information, used to
|
||||
// return values when the hash matches
|
||||
clusterActiveAdvertisement activeAdvertisement
|
||||
}
|
||||
|
||||
// CoreConfig is used to parameterize a core
|
||||
|
@ -250,7 +300,10 @@ type CoreConfig struct {
|
|||
CacheSize int `json:"cache_size" structs:"cache_size" mapstructure:"cache_size"`
|
||||
|
||||
// Set as the leader address for HA
|
||||
AdvertiseAddr string `json:"advertise_addr" structs:"advertise_addr" mapstructure:"advertise_addr"`
|
||||
RedirectAddr string `json:"redirect_addr" structs:"redirect_addr" mapstructure:"redirect_addr"`
|
||||
|
||||
// Set as the cluster address for HA
|
||||
ClusterAddr string `json:"cluster_addr" structs:"cluster_addr" mapstructure:"cluster_addr"`
|
||||
|
||||
DefaultLeaseTTL time.Duration `json:"default_lease_ttl" structs:"default_lease_ttl" mapstructure:"default_lease_ttl"`
|
||||
|
||||
|
@ -261,8 +314,10 @@ type CoreConfig struct {
|
|||
|
||||
// NewCore is used to construct a new core
|
||||
func NewCore(conf *CoreConfig) (*Core, error) {
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() && conf.AdvertiseAddr == "" {
|
||||
return nil, fmt.Errorf("missing advertisement address")
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
if conf.RedirectAddr == "" {
|
||||
return nil, fmt.Errorf("missing advertisement address")
|
||||
}
|
||||
}
|
||||
|
||||
if conf.DefaultLeaseTTL == 0 {
|
||||
|
@ -276,8 +331,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
}
|
||||
|
||||
// Validate the advertise addr if its given to us
|
||||
if conf.AdvertiseAddr != "" {
|
||||
u, err := url.Parse(conf.AdvertiseAddr)
|
||||
if conf.RedirectAddr != "" {
|
||||
u, err := url.Parse(conf.RedirectAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("advertisement address is not valid url: %s", err)
|
||||
}
|
||||
|
@ -326,18 +381,20 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
|
||||
// Setup the core
|
||||
c := &Core{
|
||||
advertiseAddr: conf.AdvertiseAddr,
|
||||
physical: conf.Physical,
|
||||
seal: conf.Seal,
|
||||
barrier: barrier,
|
||||
router: NewRouter(),
|
||||
sealed: true,
|
||||
standby: true,
|
||||
logger: conf.Logger,
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
clusterName: conf.ClusterName,
|
||||
redirectAddr: conf.RedirectAddr,
|
||||
clusterAddr: conf.ClusterAddr,
|
||||
physical: conf.Physical,
|
||||
seal: conf.Seal,
|
||||
barrier: barrier,
|
||||
router: NewRouter(),
|
||||
sealed: true,
|
||||
standby: true,
|
||||
logger: conf.Logger,
|
||||
defaultLeaseTTL: conf.DefaultLeaseTTL,
|
||||
maxLeaseTTL: conf.MaxLeaseTTL,
|
||||
cachingDisabled: conf.DisableCache,
|
||||
clusterName: conf.ClusterName,
|
||||
localClusterCertPool: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
if conf.HAPhysical != nil && conf.HAPhysical.HAEnabled() {
|
||||
|
@ -534,7 +591,18 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
|
|||
|
||||
// Check if we are the leader
|
||||
if !c.standby {
|
||||
return true, c.advertiseAddr, nil
|
||||
// If we have connections from talking to a previous leader, close them
|
||||
// out to free resources
|
||||
if c.requestForwardingConnection != nil {
|
||||
c.requestForwardingConnectionLock.Lock()
|
||||
// Verify that the condition hasn't changed
|
||||
if c.requestForwardingConnection != nil {
|
||||
c.requestForwardingConnection.Transport.(*http.Transport).CloseIdleConnections()
|
||||
}
|
||||
c.requestForwardingConnection = nil
|
||||
c.requestForwardingConnectionLock.Unlock()
|
||||
}
|
||||
return true, c.redirectAddr, nil
|
||||
}
|
||||
|
||||
// Initialize a lock
|
||||
|
@ -562,8 +630,47 @@ func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) {
|
|||
return false, "", nil
|
||||
}
|
||||
|
||||
// Leader address is in the entry
|
||||
return false, string(entry.Value), nil
|
||||
entrySHA256 := sha256.Sum256(entry.Value)
|
||||
|
||||
// Avoid JSON parsing and function calling if nothing has changed
|
||||
if c.clusterActiveAdvertisementHash != nil {
|
||||
if bytes.Compare(entrySHA256[:], c.clusterActiveAdvertisementHash) == 0 {
|
||||
return false, c.clusterActiveAdvertisement.RedirectAddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
var advAddr string
|
||||
var oldAdv bool
|
||||
|
||||
var adv activeAdvertisement
|
||||
err = jsonutil.DecodeJSON(entry.Value, &adv)
|
||||
if err != nil {
|
||||
// Fall back to pre-struct handling
|
||||
advAddr = string(entry.Value)
|
||||
oldAdv = true
|
||||
} else {
|
||||
advAddr = adv.RedirectAddr
|
||||
}
|
||||
|
||||
if !oldAdv {
|
||||
// Ensure we are using current values
|
||||
err = c.loadClusterTLS(adv)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
// This will ensure that we both have a connection at the ready and that
|
||||
// the address is the current known value
|
||||
err = c.refreshRequestForwardingConnection(adv.ClusterAddr)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
}
|
||||
|
||||
c.clusterActiveAdvertisement = adv
|
||||
c.clusterActiveAdvertisementHash = entrySHA256[:]
|
||||
|
||||
return false, advAddr, nil
|
||||
}
|
||||
|
||||
// SecretProgress returns the number of keys provided so far
|
||||
|
@ -660,6 +767,14 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
|||
|
||||
// Do post-unseal setup if HA is not enabled
|
||||
if c.ha == nil {
|
||||
// We still need to set up cluster info even if it's not part of a
|
||||
// cluster right now
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.logger.Printf("[ERR] core: cluster setup failed: %v", err)
|
||||
c.barrier.Seal()
|
||||
c.logger.Printf("[WARN] core: vault is sealed")
|
||||
return false, err
|
||||
}
|
||||
if err := c.postUnseal(); err != nil {
|
||||
c.logger.Printf("[ERR] core: post-unseal setup failed: %v", err)
|
||||
c.barrier.Seal()
|
||||
|
@ -997,8 +1112,10 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
if err := c.setupAudits(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.setupCluster(); err != nil {
|
||||
return err
|
||||
if c.ha != nil {
|
||||
if err := c.startClusterListener(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.metricsCh = make(chan struct{})
|
||||
go c.emitMetrics(c.metricsCh)
|
||||
|
@ -1023,6 +1140,10 @@ func (c *Core) preSeal() error {
|
|||
c.metricsCh = nil
|
||||
}
|
||||
var result error
|
||||
if c.ha != nil {
|
||||
c.stopClusterListener()
|
||||
}
|
||||
|
||||
if err := c.teardownAudits(); err != nil {
|
||||
result = multierror.Append(result, errwrap.Wrapf("[ERR] error tearing down audits: {{err}}", err))
|
||||
}
|
||||
|
@ -1098,8 +1219,20 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {
|
|||
// detect flapping
|
||||
activeTime := time.Now()
|
||||
|
||||
// Advertise ourself as leader
|
||||
// Grab the lock as we need it for cluster setup, which needs to happen
|
||||
// before advertising
|
||||
c.stateLock.Lock()
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.stateLock.Unlock()
|
||||
c.logger.Printf("[ERR] core: cluster setup failed: %v", err)
|
||||
lock.Unlock()
|
||||
metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime)
|
||||
continue
|
||||
}
|
||||
|
||||
// Advertise as leader
|
||||
if err := c.advertiseLeader(uuid, leaderLostCh); err != nil {
|
||||
c.stateLock.Unlock()
|
||||
c.logger.Printf("[ERR] core: leader advertisement setup failed: %v", err)
|
||||
lock.Unlock()
|
||||
metrics.MeasureSince([]string{"core", "leadership_setup_failed"}, activeTime)
|
||||
|
@ -1107,7 +1240,6 @@ func (c *Core) runStandby(doneCh, stopCh, manualStepDownCh chan struct{}) {
|
|||
}
|
||||
|
||||
// Attempt the post-unseal process
|
||||
c.stateLock.Lock()
|
||||
err = c.postUnseal()
|
||||
if err == nil {
|
||||
c.standby = false
|
||||
|
@ -1253,11 +1385,38 @@ func (c *Core) acquireLock(lock physical.Lock, stopCh <-chan struct{}) <-chan st
|
|||
// advertiseLeader is used to advertise the current node as leader
|
||||
func (c *Core) advertiseLeader(uuid string, leaderLostCh <-chan struct{}) error {
|
||||
go c.cleanLeaderPrefix(uuid, leaderLostCh)
|
||||
|
||||
var key *ecdsa.PrivateKey
|
||||
switch c.localClusterPrivateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
key = c.localClusterPrivateKey.(*ecdsa.PrivateKey)
|
||||
default:
|
||||
c.logger.Printf("[ERR] core: unknown cluster private key type %T", c.localClusterPrivateKey)
|
||||
return fmt.Errorf("unknown cluster private key type %T", c.localClusterPrivateKey)
|
||||
}
|
||||
|
||||
keyParams := clusterKeyParams{
|
||||
Type: corePrivateKeyTypeP521,
|
||||
X: key.X,
|
||||
Y: key.Y,
|
||||
D: key.D,
|
||||
}
|
||||
|
||||
adv := &activeAdvertisement{
|
||||
RedirectAddr: c.redirectAddr,
|
||||
ClusterAddr: c.clusterAddr,
|
||||
ClusterCert: c.localClusterCert,
|
||||
ClusterKeyParams: keyParams,
|
||||
}
|
||||
val, err := jsonutil.EncodeJSON(adv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ent := &Entry{
|
||||
Key: coreLeaderPrefix + uuid,
|
||||
Value: []byte(c.advertiseAddr),
|
||||
Value: val,
|
||||
}
|
||||
err := c.barrier.Put(ent)
|
||||
err = c.barrier.Put(ent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1327,3 +1486,7 @@ func (c *Core) SealAccess() *SealAccess {
|
|||
sa.SetSeal(c.seal)
|
||||
return sa
|
||||
}
|
||||
|
||||
func (c *Core) Logger() *log.Logger {
|
||||
return c.logger
|
||||
}
|
||||
|
|
|
@ -19,12 +19,12 @@ var (
|
|||
invalidKey = []byte("abcdefghijklmnopqrstuvwxyz")[:17]
|
||||
)
|
||||
|
||||
func TestNewCore_badAdvertiseAddr(t *testing.T) {
|
||||
func TestNewCore_badRedirectAddr(t *testing.T) {
|
||||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
conf := &CoreConfig{
|
||||
AdvertiseAddr: "127.0.0.1:8200",
|
||||
Physical: physical.NewInmem(logger),
|
||||
DisableMlock: true,
|
||||
RedirectAddr: "127.0.0.1:8200",
|
||||
Physical: physical.NewInmem(logger),
|
||||
DisableMlock: true,
|
||||
}
|
||||
_, err := NewCore(conf)
|
||||
if err == nil {
|
||||
|
@ -46,7 +46,7 @@ func TestSealConfig_Invalid(t *testing.T) {
|
|||
func TestCore_Unseal_MultiShare(t *testing.T) {
|
||||
c := TestCore(t)
|
||||
|
||||
_, err := c.Unseal(invalidKey)
|
||||
_, err := TestCoreUnseal(c, invalidKey)
|
||||
if err != ErrNotInit {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -73,13 +73,13 @@ func TestCore_Unseal_MultiShare(t *testing.T) {
|
|||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
unseal, err := c.Unseal(res.SecretShares[i])
|
||||
unseal, err := TestCoreUnseal(c, res.SecretShares[i])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Ignore redundant
|
||||
_, err = c.Unseal(res.SecretShares[i])
|
||||
_, err = TestCoreUnseal(c, res.SecretShares[i])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ func TestCore_Unseal_MultiShare(t *testing.T) {
|
|||
func TestCore_Unseal_Single(t *testing.T) {
|
||||
c := TestCore(t)
|
||||
|
||||
_, err := c.Unseal(invalidKey)
|
||||
_, err := TestCoreUnseal(c, invalidKey)
|
||||
if err != ErrNotInit {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ func TestCore_Unseal_Single(t *testing.T) {
|
|||
t.Fatalf("bad progress: %d", prog)
|
||||
}
|
||||
|
||||
unseal, err := c.Unseal(res.SecretShares[0])
|
||||
unseal, err := TestCoreUnseal(c, res.SecretShares[0])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ func TestCore_Route_Sealed(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
unseal, err := c.Unseal(res.SecretShares[0])
|
||||
unseal, err := TestCoreUnseal(c, res.SecretShares[0])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -222,7 +222,7 @@ func TestCore_SealUnseal(t *testing.T) {
|
|||
if err := c.Seal(root); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if unseal, err := c.Unseal(key); err != nil || !unseal {
|
||||
if unseal, err := TestCoreUnseal(c, key); err != nil || !unseal {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -958,18 +958,18 @@ func TestCore_Standby_Seal(t *testing.T) {
|
|||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
inm := physical.NewInmem(logger)
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -983,7 +983,7 @@ func TestCore_Standby_Seal(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Check the leader is local
|
||||
isLeader, advertise, err := core.Leader()
|
||||
|
@ -993,22 +993,22 @@ func TestCore_Standby_Seal(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
// Create the second core and initialize it
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1038,7 +1038,7 @@ func TestCore_Standby_Seal(t *testing.T) {
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1064,18 +1064,18 @@ func TestCore_StepDown(t *testing.T) {
|
|||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
inm := physical.NewInmem(logger)
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1089,7 +1089,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Check the leader is local
|
||||
isLeader, advertise, err := core.Leader()
|
||||
|
@ -1099,22 +1099,22 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
// Create the second core and initialize it
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1144,7 +1144,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1185,7 +1185,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal2 {
|
||||
if advertise != redirectOriginal2 {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1197,7 +1197,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal2 {
|
||||
if advertise != redirectOriginal2 {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1228,7 +1228,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1240,7 +1240,7 @@ func TestCore_StepDown(t *testing.T) {
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
}
|
||||
|
@ -1250,18 +1250,18 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
inm := physical.NewInmem(logger)
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1275,7 +1275,7 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Ensure that the original clean function has stopped running
|
||||
time.Sleep(2 * time.Second)
|
||||
|
@ -1312,22 +1312,22 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
// Create a second core, attached to same in-memory store
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1357,7 +1357,7 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1377,7 +1377,7 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core2 to become active
|
||||
testWaitActive(t, core2)
|
||||
TestWaitActive(t, core2)
|
||||
|
||||
// Check the leader is local
|
||||
isLeader, advertise, err = core2.Leader()
|
||||
|
@ -1387,7 +1387,7 @@ func TestCore_CleanLeaderPrefix(t *testing.T) {
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal2 {
|
||||
if advertise != redirectOriginal2 {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1416,18 +1416,18 @@ func TestCore_Standby_SeparateHA(t *testing.T) {
|
|||
|
||||
func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.HABackend) {
|
||||
// Create the first core and initialize it
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1441,7 +1441,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Put a secret
|
||||
req := &logical.Request{
|
||||
|
@ -1465,22 +1465,22 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
// Create a second core, attached to same in-memory store
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -1516,7 +1516,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||
if isLeader {
|
||||
t.Fatalf("should not be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal {
|
||||
if advertise != redirectOriginal {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1536,7 +1536,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||
}
|
||||
|
||||
// Wait for core2 to become active
|
||||
testWaitActive(t, core2)
|
||||
TestWaitActive(t, core2)
|
||||
|
||||
// Read the secret
|
||||
req = &logical.Request{
|
||||
|
@ -1562,7 +1562,7 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical.
|
|||
if !isLeader {
|
||||
t.Fatalf("should be leader")
|
||||
}
|
||||
if advertise != advertiseOriginal2 {
|
||||
if advertise != redirectOriginal2 {
|
||||
t.Fatalf("Bad advertise: %v", advertise)
|
||||
}
|
||||
|
||||
|
@ -1948,59 +1948,41 @@ func TestCore_HandleRequest_MountPoint(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testWaitActive(t *testing.T, core *Core) {
|
||||
start := time.Now()
|
||||
var standby bool
|
||||
var err error
|
||||
for time.Now().Sub(start) < time.Second {
|
||||
standby, err = core.Standby()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !standby {
|
||||
break
|
||||
}
|
||||
}
|
||||
if standby {
|
||||
t.Fatalf("should not be in standby mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCore_Standby_Rotate(t *testing.T) {
|
||||
// Create the first core and initialize it
|
||||
logger = log.New(os.Stderr, "", log.LstdFlags)
|
||||
inm := physical.NewInmem(logger)
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Create a second core, attached to same in-memory store
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -2022,7 +2004,7 @@ func TestCore_Standby_Rotate(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core2 to become active
|
||||
testWaitActive(t, core2)
|
||||
TestWaitActive(t, core2)
|
||||
|
||||
// Read the key status
|
||||
req = &logical.Request{
|
||||
|
|
|
@ -173,8 +173,13 @@ func (c *Core) Initialize(barrierConfig, recoveryConfig *SealConfig) (*InitResul
|
|||
}()
|
||||
|
||||
// Perform initial setup
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.stateLock.Unlock()
|
||||
c.logger.Printf("[ERR] core: cluster setup failed during init: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if err := c.postUnseal(); err != nil {
|
||||
c.logger.Printf("[ERR] core: post-unseal setup failed: %v", err)
|
||||
c.logger.Printf("[ERR] core: post-unseal setup failed during init: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func TestCore_DefaultMountTable(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func TestCore_Mount(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -98,7 +98,7 @@ func TestCore_Unmount(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ func TestCore_Remount(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c2.Unseal(key)
|
||||
unseal, err := TestCoreUnseal(c2, key)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
|
|
@ -225,7 +225,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err = c.Unseal(result.SecretShares[i])
|
||||
_, err = TestCoreUnseal(c, result.SecretShares[i])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -282,7 +282,7 @@ func testCore_Rekey_Update_Common(t *testing.T, c *Core, keys [][]byte, root str
|
|||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
unseal, err := c.Unseal(result.SecretShares[0])
|
||||
unseal, err := TestCoreUnseal(c, result.SecretShares[0])
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -364,36 +364,36 @@ func TestCore_Standby_Rekey(t *testing.T) {
|
|||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
inm := physical.NewInmem(logger)
|
||||
inmha := physical.NewInmemHA(logger)
|
||||
advertiseOriginal := "http://127.0.0.1:8200"
|
||||
redirectOriginal := "http://127.0.0.1:8200"
|
||||
core, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
key, root := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
// Wait for core to become active
|
||||
testWaitActive(t, core)
|
||||
TestWaitActive(t, core)
|
||||
|
||||
// Create a second core, attached to same in-memory store
|
||||
advertiseOriginal2 := "http://127.0.0.1:8500"
|
||||
redirectOriginal2 := "http://127.0.0.1:8500"
|
||||
core2, err := NewCore(&CoreConfig{
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
AdvertiseAddr: advertiseOriginal2,
|
||||
DisableMlock: true,
|
||||
Physical: inm,
|
||||
HAPhysical: inmha,
|
||||
RedirectAddr: redirectOriginal2,
|
||||
DisableMlock: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := core2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core2, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -429,7 +429,7 @@ func TestCore_Standby_Rekey(t *testing.T) {
|
|||
}
|
||||
|
||||
// Wait for core2 to become active
|
||||
testWaitActive(t, core2)
|
||||
TestWaitActive(t, core2)
|
||||
|
||||
// Rekey the master key again
|
||||
err = core2.RekeyInit(newConf, false)
|
||||
|
|
456
vault/testing.go
456
vault/testing.go
|
@ -4,9 +4,13 @@ import (
|
|||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
@ -131,6 +135,11 @@ func TestCoreWithSeal(t *testing.T, testSeal Seal) *Core {
|
|||
// TestCoreInit initializes the core with a single key, and returns
|
||||
// the key that must be used to unseal the core and a root token.
|
||||
func TestCoreInit(t *testing.T, core *Core) ([]byte, string) {
|
||||
return TestCoreInitClusterListenerSetup(t, core, func() ([]net.Listener, http.Handler, error) { return nil, nil, nil })
|
||||
}
|
||||
|
||||
func TestCoreInitClusterListenerSetup(t *testing.T, core *Core, setupFunc func() ([]net.Listener, http.Handler, error)) ([]byte, string) {
|
||||
core.SetClusterListenerSetupFunc(setupFunc)
|
||||
result, err := core.Initialize(&SealConfig{
|
||||
SecretShares: 1,
|
||||
SecretThreshold: 1,
|
||||
|
@ -141,12 +150,17 @@ func TestCoreInit(t *testing.T, core *Core) ([]byte, string) {
|
|||
return result.SecretShares[0], result.RootToken
|
||||
}
|
||||
|
||||
func TestCoreUnseal(core *Core, key []byte) (bool, error) {
|
||||
core.SetClusterListenerSetupFunc(func() ([]net.Listener, http.Handler, error) { return nil, nil, nil })
|
||||
return core.Unseal(key)
|
||||
}
|
||||
|
||||
// TestCoreUnsealed returns a pure in-memory core that is already
|
||||
// initialized and unsealed.
|
||||
func TestCoreUnsealed(t *testing.T) (*Core, []byte, string) {
|
||||
core := TestCore(t)
|
||||
key, token := TestCoreInit(t, core)
|
||||
if _, err := core.Unseal(TestKeyCopy(key)); err != nil {
|
||||
if _, err := TestCoreUnseal(core, TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
|
@ -377,3 +391,443 @@ func GenerateRandBytes(length int) ([]byte, error) {
|
|||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func TestWaitActive(t *testing.T, core *Core) {
|
||||
start := time.Now()
|
||||
var standby bool
|
||||
var err error
|
||||
for time.Now().Sub(start) < time.Second {
|
||||
standby, err = core.Standby()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !standby {
|
||||
break
|
||||
}
|
||||
}
|
||||
if standby {
|
||||
t.Fatalf("should not be in standby mode")
|
||||
}
|
||||
}
|
||||
|
||||
type TestListener struct {
|
||||
net.Listener
|
||||
Address *net.TCPAddr
|
||||
}
|
||||
|
||||
type TestClusterCore struct {
|
||||
*Core
|
||||
Listeners []*TestListener
|
||||
Root string
|
||||
Key []byte
|
||||
CACertBytes []byte
|
||||
CACert *x509.Certificate
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
func (t *TestClusterCore) CloseListeners() {
|
||||
if t.Listeners != nil {
|
||||
for _, ln := range t.Listeners {
|
||||
ln.Close()
|
||||
}
|
||||
}
|
||||
// Give time to actually shut down/clean up before the next test
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unsealStandbys bool) []*TestClusterCore {
|
||||
if handlers == nil || len(handlers) != 3 {
|
||||
t.Fatal("handlers must be size 3")
|
||||
}
|
||||
|
||||
//
|
||||
// TLS setup
|
||||
//
|
||||
block, _ := pem.Decode([]byte(TestClusterCACert))
|
||||
if block == nil {
|
||||
t.Fatal("error decoding cluster CA cert")
|
||||
}
|
||||
caBytes := block.Bytes
|
||||
caCert, err := x509.ParseCertificate(caBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverCert, err := tls.X509KeyPair([]byte(TestClusterServerCert), []byte(TestClusterServerKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AppendCertsFromPEM([]byte(TestClusterCACert))
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
RootCAs: rootCAs,
|
||||
ClientCAs: rootCAs,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
|
||||
// Sanity checking
|
||||
block, _ = pem.Decode([]byte(TestClusterServerCert))
|
||||
if block == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedServerCert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
chains, err := parsedServerCert.Verify(x509.VerifyOptions{
|
||||
DNSName: "127.0.0.1",
|
||||
Roots: rootCAs,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if chains == nil || len(chains) == 0 {
|
||||
t.Fatal("no verified chains for server auth")
|
||||
}
|
||||
chains, err = parsedServerCert.Verify(x509.VerifyOptions{
|
||||
DNSName: "127.0.0.1",
|
||||
Roots: rootCAs,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if chains == nil || len(chains) == 0 {
|
||||
t.Fatal("no verified chains for chains auth")
|
||||
}
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
|
||||
//
|
||||
// Listener setup
|
||||
//
|
||||
ln, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c1lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, tlsConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
ln, err = net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c1lns = append(c1lns, &TestListener{
|
||||
Listener: tls.NewListener(ln, tlsConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
})
|
||||
server1 := &http.Server{
|
||||
Handler: handlers[0],
|
||||
}
|
||||
for _, ln := range c1lns {
|
||||
go server1.Serve(ln)
|
||||
}
|
||||
|
||||
ln, err = net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c2lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, tlsConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
server2 := &http.Server{
|
||||
Handler: handlers[1],
|
||||
}
|
||||
for _, ln := range c2lns {
|
||||
go server2.Serve(ln)
|
||||
}
|
||||
|
||||
ln, err = net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c3lns := []*TestListener{&TestListener{
|
||||
Listener: tls.NewListener(ln, tlsConfig),
|
||||
Address: ln.Addr().(*net.TCPAddr),
|
||||
},
|
||||
}
|
||||
server3 := &http.Server{
|
||||
Handler: handlers[2],
|
||||
}
|
||||
for _, ln := range c3lns {
|
||||
go server3.Serve(ln)
|
||||
}
|
||||
|
||||
// Create three cores with the same physical and different redirect/cluster addrs
|
||||
coreConfig := &CoreConfig{
|
||||
Physical: physical.NewInmem(logger),
|
||||
HAPhysical: physical.NewInmemHA(logger),
|
||||
LogicalBackends: make(map[string]logical.Factory),
|
||||
CredentialBackends: make(map[string]logical.Factory),
|
||||
AuditBackends: make(map[string]audit.Factory),
|
||||
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port),
|
||||
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+1),
|
||||
DisableMlock: true,
|
||||
}
|
||||
|
||||
coreConfig.LogicalBackends["generic"] = PassthroughBackendFactory
|
||||
|
||||
if base != nil {
|
||||
// Used to set something non-working to test fallback
|
||||
switch base.ClusterAddr {
|
||||
case "empty":
|
||||
coreConfig.ClusterAddr = ""
|
||||
case "":
|
||||
default:
|
||||
coreConfig.ClusterAddr = base.ClusterAddr
|
||||
}
|
||||
|
||||
if base.LogicalBackends != nil {
|
||||
for k, v := range base.LogicalBackends {
|
||||
coreConfig.LogicalBackends[k] = v
|
||||
}
|
||||
}
|
||||
if base.CredentialBackends != nil {
|
||||
for k, v := range base.CredentialBackends {
|
||||
coreConfig.CredentialBackends[k] = v
|
||||
}
|
||||
}
|
||||
if base.AuditBackends != nil {
|
||||
for k, v := range base.AuditBackends {
|
||||
coreConfig.AuditBackends[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c1, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+1)
|
||||
}
|
||||
c2, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
|
||||
if coreConfig.ClusterAddr != "" {
|
||||
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+1)
|
||||
}
|
||||
c3, err := NewCore(coreConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
//
|
||||
// Clustering setup
|
||||
//
|
||||
clusterAddrGen := func(lns []*TestListener) []string {
|
||||
ret := make([]string, len(lns))
|
||||
for i, ln := range lns {
|
||||
curAddr := ln.Address
|
||||
ipStr := curAddr.IP.String()
|
||||
if len(curAddr.IP) == net.IPv6len {
|
||||
ipStr = fmt.Sprintf("[%s]", ipStr)
|
||||
}
|
||||
ret[i] = fmt.Sprintf("%s:%d", ipStr, curAddr.Port+1)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
c2.SetClusterListenerSetupFunc(WrapListenersForClustering(clusterAddrGen(c2lns), handlers[1], logger))
|
||||
c3.SetClusterListenerSetupFunc(WrapListenersForClustering(clusterAddrGen(c3lns), handlers[2], logger))
|
||||
key, root := TestCoreInitClusterListenerSetup(t, c1, WrapListenersForClustering(clusterAddrGen(c1lns), handlers[0], logger))
|
||||
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
// Verify unsealed
|
||||
sealed, err := c1.Sealed()
|
||||
if err != nil {
|
||||
t.Fatalf("err checking seal status: %s", err)
|
||||
}
|
||||
if sealed {
|
||||
t.Fatal("should not be sealed")
|
||||
}
|
||||
|
||||
TestWaitActive(t, c1)
|
||||
|
||||
if unsealStandbys {
|
||||
if _, err := c2.Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
if _, err := c3.Unseal(TestKeyCopy(key)); err != nil {
|
||||
t.Fatalf("unseal err: %s", err)
|
||||
}
|
||||
|
||||
// Let them come fully up to standby
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Ensure cluster connection info is populated
|
||||
isLeader, _, err := c2.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("c2 should not be leader")
|
||||
}
|
||||
isLeader, _, err = c3.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("c3 should not be leader")
|
||||
}
|
||||
}
|
||||
|
||||
return []*TestClusterCore{
|
||||
&TestClusterCore{
|
||||
Core: c1,
|
||||
Listeners: c1lns,
|
||||
Root: root,
|
||||
Key: TestKeyCopy(key),
|
||||
CACertBytes: caBytes,
|
||||
CACert: caCert,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
&TestClusterCore{
|
||||
Core: c2,
|
||||
Listeners: c2lns,
|
||||
Root: root,
|
||||
Key: TestKeyCopy(key),
|
||||
CACertBytes: caBytes,
|
||||
CACert: caCert,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
&TestClusterCore{
|
||||
Core: c3,
|
||||
Listeners: c3lns,
|
||||
Root: root,
|
||||
Key: TestKeyCopy(key),
|
||||
CACertBytes: caBytes,
|
||||
CACert: caCert,
|
||||
TLSConfig: tlsConfig,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
TestClusterCACert = `-----BEGIN CERTIFICATE-----
|
||||
MIIDPjCCAiagAwIBAgIUfIKsF2VPT7sdFcKOHJH2Ii6K4MwwDQYJKoZIhvcNAQEL
|
||||
BQAwFjEUMBIGA1UEAxMLbXl2YXVsdC5jb20wIBcNMTYwNTAyMTYwNTQyWhgPMjA2
|
||||
NjA0MjAxNjA2MTJaMBYxFDASBgNVBAMTC215dmF1bHQuY29tMIIBIjANBgkqhkiG
|
||||
9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuOimEXawD2qBoLCFP3Skq5zi1XzzcMAJlfdS
|
||||
xz9hfymuJb+cN8rB91HOdU9wQCwVKnkUtGWxUnMp0tT0uAZj5NzhNfyinf0JGAbP
|
||||
67HDzVZhGBHlHTjPX0638yaiUx90cTnucX0N20SgCYct29dMSgcPl+W78D3Jw3xE
|
||||
JsHQPYS9ASe2eONxG09F/qNw7w/RO5/6WYoV2EmdarMMxq52pPe2chtNMQdSyOUb
|
||||
cCcIZyk4QVFZ1ZLl6jTnUPb+JoCx1uMxXvMek4NF/5IL0Wr9dw2gKXKVKoHDr6SY
|
||||
WrCONRw61A5Zwx1V+kn73YX3USRlkufQv/ih6/xThYDAXDC9cwIDAQABo4GBMH8w
|
||||
DgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFOuKvPiU
|
||||
G06iHkRXAOeMiUdBfHFyMB8GA1UdIwQYMBaAFOuKvPiUG06iHkRXAOeMiUdBfHFy
|
||||
MBwGA1UdEQQVMBOCC215dmF1bHQuY29thwR/AAABMA0GCSqGSIb3DQEBCwUAA4IB
|
||||
AQBcN/UdAMzc7UjRdnIpZvO+5keBGhL/vjltnGM1dMWYHa60Y5oh7UIXF+P1RdNW
|
||||
n7g80lOyvkSR15/r1rDkqOK8/4oruXU31EcwGhDOC4hU6yMUy4ltV/nBoodHBXNh
|
||||
MfKiXeOstH1vdI6G0P6W93Bcww6RyV1KH6sT2dbETCw+iq2VN9CrruGIWzd67UT/
|
||||
spe/kYttr3UYVV3O9kqgffVVgVXg/JoRZ3J7Hy2UEXfh9UtWNanDlRuXaZgE9s/d
|
||||
CpA30CHpNXvKeyNeW2ktv+2nAbSpvNW+e6MecBCTBIoDSkgU8ShbrzmDKVwNN66Q
|
||||
5gn6KxUPBKHEtNzs5DgGM7nq
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
TestClusterCAKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAuOimEXawD2qBoLCFP3Skq5zi1XzzcMAJlfdSxz9hfymuJb+c
|
||||
N8rB91HOdU9wQCwVKnkUtGWxUnMp0tT0uAZj5NzhNfyinf0JGAbP67HDzVZhGBHl
|
||||
HTjPX0638yaiUx90cTnucX0N20SgCYct29dMSgcPl+W78D3Jw3xEJsHQPYS9ASe2
|
||||
eONxG09F/qNw7w/RO5/6WYoV2EmdarMMxq52pPe2chtNMQdSyOUbcCcIZyk4QVFZ
|
||||
1ZLl6jTnUPb+JoCx1uMxXvMek4NF/5IL0Wr9dw2gKXKVKoHDr6SYWrCONRw61A5Z
|
||||
wx1V+kn73YX3USRlkufQv/ih6/xThYDAXDC9cwIDAQABAoIBAG3bCo7ljMQb6tel
|
||||
CAUjL5Ilqz5a9ebOsONABRYLOclq4ePbatxawdJF7/sSLwZxKkIJnZtvr2Hkubxg
|
||||
eOO8KC0YbVS9u39Rjc2QfobxHfsojpbWSuCJl+pvwinbkiUAUxXR7S/PtCPJKat/
|
||||
fGdYCiMQ/tqnynh4vR4+/d5o12c0KuuQ22/MdEf3GOadUamRXS1ET9iJWqla1pJW
|
||||
TmzrlkGAEnR5PPO2RMxbnZCYmj3dArxWAnB57W+bWYla0DstkDKtwg2j2ikNZpXB
|
||||
nkZJJpxR76IYD1GxfwftqAKxujKcyfqB0dIKCJ0UmfOkauNWjexroNLwaAOC3Nud
|
||||
XIxppAECgYEA1wJ9EH6A6CrSjdzUocF9LtQy1LCDHbdiQFHxM5/zZqIxraJZ8Gzh
|
||||
Q0d8JeOjwPdG4zL9pHcWS7+x64Wmfn0+Qfh6/47Vy3v90PIL0AeZYshrVZyJ/s6X
|
||||
YkgFK80KEuWtacqIZ1K2UJyCw81u/ynIl2doRsIbgkbNeN0opjmqVTMCgYEA3CkW
|
||||
2fETWK1LvmgKFjG1TjOotVRIOUfy4iN0kznPm6DK2PgTF5DX5RfktlmA8i8WPmB7
|
||||
YFOEdAWHf+RtoM/URa7EAGZncCWe6uggAcWqznTS619BJ63OmncpSWov5Byg90gJ
|
||||
48qIMY4wDjE85ypz1bmBc2Iph974dtWeDtB7dsECgYAyKZh4EquMfwEkq9LH8lZ8
|
||||
aHF7gbr1YeWAUB3QB49H8KtacTg+iYh8o97pEBUSXh6hvzHB/y6qeYzPAB16AUpX
|
||||
Jdu8Z9ylXsY2y2HKJRu6GjxAewcO9bAH8/mQ4INrKT6uIdx1Dq0OXZV8jR9KVLtB
|
||||
55RCfeLhIBesDR0Auw9sVQKBgB0xTZhkgP43LF35Ca1btgDClNJGdLUztx8JOIH1
|
||||
HnQyY/NVIaL0T8xO2MLdJ131pGts+68QI/YGbaslrOuv4yPCQrcS3RBfzKy1Ttkt
|
||||
TrLFhtoy7T7HqyeMOWtEq0kCCs3/PWB5EIoRoomfOcYlOOrUCDg2ge9EP4nyVVz9
|
||||
hAGBAoGBAJXw/ufevxpBJJMSyULmVWYr34GwLC1OhSE6AVVt9JkIYnc5L4xBKTHP
|
||||
QNKKJLmFmMsEqfxHUNWmpiHkm2E0p37Zehui3kywo+A4ybHPTua70ZWQfZhKxLUr
|
||||
PvJa8JmwiCM7kO8zjOv+edY1mMWrbjAZH1YUbfcTHmST7S8vp0F3
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
TestClusterServerCert = `-----BEGIN CERTIFICATE-----
|
||||
MIIDtzCCAp+gAwIBAgIUBLqh6ctGWVDUxFhxJX7m6S/bnrcwDQYJKoZIhvcNAQEL
|
||||
BQAwFjEUMBIGA1UEAxMLbXl2YXVsdC5jb20wIBcNMTYwNTAyMTYwOTI2WhgPMjA2
|
||||
NjA0MjAxNTA5NTZaMBsxGTAXBgNVBAMTEGNlcnQubXl2YXVsdC5jb20wggEiMA0G
|
||||
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDY3gPB29kkdbu0mPO6J0efagQhSiXB
|
||||
9OyDuLf5sMk6CVDWVWal5hISkyBmw/lXgF7qC2XFKivpJOrcGQd5Ep9otBqyJLzI
|
||||
b0IWdXuPIrVnXDwcdWr86ybX2iC42zKWfbXgjzGijeAVpl0UJLKBj+fk5q6NvkRL
|
||||
5FUL6TRV7Krn9mrmnrV9J5IqV15pTd9W2aVJ6IqWvIPCACtZKulqWn4707uy2X2W
|
||||
1Stq/5qnp1pDshiGk1VPyxCwQ6yw3iEcgecbYo3vQfhWcv7Q8LpSIM9ZYpXu6OmF
|
||||
+czqRZS9gERl+wipmmrN1MdYVrTuQem21C/PNZ4jo4XUk1SFx6JrcA+lAgMBAAGj
|
||||
gfUwgfIwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSe
|
||||
Cl9WV3BjGCwmS/KrDSLRjfwyqjAfBgNVHSMEGDAWgBTrirz4lBtOoh5EVwDnjIlH
|
||||
QXxxcjA7BggrBgEFBQcBAQQvMC0wKwYIKwYBBQUHMAKGH2h0dHA6Ly8xMjcuMC4w
|
||||
LjE6ODIwMC92MS9wa2kvY2EwIQYDVR0RBBowGIIQY2VydC5teXZhdWx0LmNvbYcE
|
||||
fwAAATAxBgNVHR8EKjAoMCagJKAihiBodHRwOi8vMTI3LjAuMC4xOjgyMDAvdjEv
|
||||
cGtpL2NybDANBgkqhkiG9w0BAQsFAAOCAQEAWGholPN8buDYwKbUiDavbzjsxUIX
|
||||
lU4MxEqOHw7CD3qIYIauPboLvB9EldBQwhgOOy607Yvdg3rtyYwyBFwPhHo/hK3Z
|
||||
6mn4hc6TF2V+AUdHBvGzp2dbYLeo8noVoWbQ/lBulggwlIHNNF6+a3kALqsqk1Ch
|
||||
f/hzsjFnDhAlNcYFgG8TgfE2lE/FckvejPqBffo7Q3I+wVAw0buqiz5QL81NOT+D
|
||||
Y2S9LLKLRaCsWo9wRU1Az4Rhd7vK5SEMh16jJ82GyEODWPvuxOTI1MnzfnbWyLYe
|
||||
TTp6YBjGMVf1I6NEcWNur7U17uIOiQjMZ9krNvoMJ1A/cxCoZ98QHgcIPg==
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
TestClusterServerKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEA2N4DwdvZJHW7tJjzuidHn2oEIUolwfTsg7i3+bDJOglQ1lVm
|
||||
peYSEpMgZsP5V4Be6gtlxSor6STq3BkHeRKfaLQasiS8yG9CFnV7jyK1Z1w8HHVq
|
||||
/Osm19oguNsyln214I8xoo3gFaZdFCSygY/n5Oaujb5ES+RVC+k0Veyq5/Zq5p61
|
||||
fSeSKldeaU3fVtmlSeiKlryDwgArWSrpalp+O9O7stl9ltUrav+ap6daQ7IYhpNV
|
||||
T8sQsEOssN4hHIHnG2KN70H4VnL+0PC6UiDPWWKV7ujphfnM6kWUvYBEZfsIqZpq
|
||||
zdTHWFa07kHpttQvzzWeI6OF1JNUhceia3APpQIDAQABAoIBAQCH3vEzr+3nreug
|
||||
RoPNCXcSJXXY9X+aeT0FeeGqClzIg7Wl03OwVOjVwl/2gqnhbIgK0oE8eiNwurR6
|
||||
mSPZcxV0oAJpwiKU4T/imlCDaReGXn86xUX2l82KRxthNdQH/VLKEmzij0jpx4Vh
|
||||
bWx5SBPdkbmjDKX1dmTiRYWIn/KjyNPvNvmtwdi8Qluhf4eJcNEUr2BtblnGOmfL
|
||||
FdSu+brPJozpoQ1QdDnbAQRgqnh7Shl0tT85whQi0uquqIj1gEOGVjmBvDDnL3GV
|
||||
WOENTKqsmIIoEzdZrql1pfmYTk7WNaD92bfpN128j8BF7RmAV4/DphH0pvK05y9m
|
||||
tmRhyHGxAoGBAOV2BBocsm6xup575VqmFN+EnIOiTn+haOvfdnVsyQHnth63fOQx
|
||||
PNtMpTPR1OMKGpJ13e2bV0IgcYRsRkScVkUtoa/17VIgqZXffnJJ0A/HT67uKBq3
|
||||
8o7RrtyK5N20otw0lZHyqOPhyCdpSsurDhNON1kPVJVYY4N1RiIxfut/AoGBAPHz
|
||||
HfsJ5ZkyELE9N/r4fce04lprxWH+mQGK0/PfjS9caXPhj/r5ZkVMvzWesF3mmnY8
|
||||
goE5S35TuTvV1+6rKGizwlCFAQlyXJiFpOryNWpLwCmDDSzLcm+sToAlML3tMgWU
|
||||
jM3dWHx3C93c3ft4rSWJaUYI9JbHsMzDW6Yh+GbbAoGBANIbKwxh5Hx5XwEJP2yu
|
||||
kIROYCYkMy6otHLujgBdmPyWl+suZjxoXWoMl2SIqR8vPD+Jj6mmyNJy9J6lqf3f
|
||||
DRuQ+fEuBZ1i7QWfvJ+XuN0JyovJ5Iz6jC58D1pAD+p2IX3y5FXcVQs8zVJRFjzB
|
||||
p0TEJOf2oqORaKWRd6ONoMKvAoGALKu6aVMWdQZtVov6/fdLIcgf0pn7Q3CCR2qe
|
||||
X3Ry2L+zKJYIw0mwvDLDSt8VqQCenB3n6nvtmFFU7ds5lvM67rnhsoQcAOaAehiS
|
||||
rl4xxoJd5Ewx7odRhZTGmZpEOYzFo4odxRSM9c30/u18fqV1Mm0AZtHYds4/sk6P
|
||||
aUj0V+kCgYBMpGrJk8RSez5g0XZ35HfpI4ENoWbiwB59FIpWsLl2LADEh29eC455
|
||||
t9Muq7MprBVBHQo11TMLLFxDIjkuMho/gcKgpYXCt0LfiNm8EZehvLJUXH+3WqUx
|
||||
we6ywrbFCs6LaxaOCtTiLsN+GbZCatITL0UJaeBmTAbiw0KQjUuZPQ==
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
#
|
||||
# This Dockerfile builds a recent curl with HTTP/2 client support, using
|
||||
# a recent nghttp2 build.
|
||||
#
|
||||
# See the Makefile for how to tag it. If Docker and that image is found, the
|
||||
# Go tests use this curl binary for integration tests.
|
||||
#
|
||||
|
||||
FROM ubuntu:trusty
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get upgrade -y && \
|
||||
apt-get install -y git-core build-essential wget
|
||||
|
||||
RUN apt-get install -y --no-install-recommends \
|
||||
autotools-dev libtool pkg-config zlib1g-dev \
|
||||
libcunit1-dev libssl-dev libxml2-dev libevent-dev \
|
||||
automake autoconf
|
||||
|
||||
# The list of packages nghttp2 recommends for h2load:
|
||||
RUN apt-get install -y --no-install-recommends make binutils \
|
||||
autoconf automake autotools-dev \
|
||||
libtool pkg-config zlib1g-dev libcunit1-dev libssl-dev libxml2-dev \
|
||||
libev-dev libevent-dev libjansson-dev libjemalloc-dev \
|
||||
cython python3.4-dev python-setuptools
|
||||
|
||||
# Note: setting NGHTTP2_VER before the git clone, so an old git clone isn't cached:
|
||||
ENV NGHTTP2_VER 895da9a
|
||||
RUN cd /root && git clone https://github.com/tatsuhiro-t/nghttp2.git
|
||||
|
||||
WORKDIR /root/nghttp2
|
||||
RUN git reset --hard $NGHTTP2_VER
|
||||
RUN autoreconf -i
|
||||
RUN automake
|
||||
RUN autoconf
|
||||
RUN ./configure
|
||||
RUN make
|
||||
RUN make install
|
||||
|
||||
WORKDIR /root
|
||||
RUN wget http://curl.haxx.se/download/curl-7.45.0.tar.gz
|
||||
RUN tar -zxvf curl-7.45.0.tar.gz
|
||||
WORKDIR /root/curl-7.45.0
|
||||
RUN ./configure --with-ssl --with-nghttp2=/usr/local
|
||||
RUN make
|
||||
RUN make install
|
||||
RUN ldconfig
|
||||
|
||||
CMD ["-h"]
|
||||
ENTRYPOINT ["/usr/local/bin/curl"]
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
curlimage:
|
||||
docker build -t gohttp2/curl .
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
This is a work-in-progress HTTP/2 implementation for Go.
|
||||
|
||||
It will eventually live in the Go standard library and won't require
|
||||
any changes to your code to use. It will just be automatic.
|
||||
|
||||
Status:
|
||||
|
||||
* The server support is pretty good. A few things are missing
|
||||
but are being worked on.
|
||||
* The client work has just started but shares a lot of code
|
||||
is coming along much quicker.
|
||||
|
||||
Docs are at https://godoc.org/golang.org/x/net/http2
|
||||
|
||||
Demo test server at https://http2.golang.org/
|
||||
|
||||
Help & bug reports welcome!
|
||||
|
||||
Contributing: https://golang.org/doc/contribute.html
|
||||
Bugs: https://golang.org/issue/new?title=x/net/http2:+
|
|
@ -0,0 +1,256 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Transport code's client connection pooling.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ClientConnPool manages a pool of HTTP/2 client connections.
|
||||
type ClientConnPool interface {
|
||||
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
|
||||
MarkDead(*ClientConn)
|
||||
}
|
||||
|
||||
// clientConnPoolIdleCloser is the interface implemented by ClientConnPool
|
||||
// implementations which can close their idle connections.
|
||||
type clientConnPoolIdleCloser interface {
|
||||
ClientConnPool
|
||||
closeIdleConnections()
|
||||
}
|
||||
|
||||
var (
|
||||
_ clientConnPoolIdleCloser = (*clientConnPool)(nil)
|
||||
_ clientConnPoolIdleCloser = noDialClientConnPool{}
|
||||
)
|
||||
|
||||
// TODO: use singleflight for dialing and addConnCalls?
|
||||
type clientConnPool struct {
|
||||
t *Transport
|
||||
|
||||
mu sync.Mutex // TODO: maybe switch to RWMutex
|
||||
// TODO: add support for sharing conns based on cert names
|
||||
// (e.g. share conn for googleapis.com and appspot.com)
|
||||
conns map[string][]*ClientConn // key is host:port
|
||||
dialing map[string]*dialCall // currently in-flight dials
|
||||
keys map[*ClientConn][]string
|
||||
addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
|
||||
}
|
||||
|
||||
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
|
||||
return p.getClientConn(req, addr, dialOnMiss)
|
||||
}
|
||||
|
||||
const (
|
||||
dialOnMiss = true
|
||||
noDialOnMiss = false
|
||||
)
|
||||
|
||||
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
|
||||
if isConnectionCloseRequest(req) && dialOnMiss {
|
||||
// It gets its own connection.
|
||||
const singleUse = true
|
||||
cc, err := p.t.dialClientConn(addr, singleUse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cc, nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
for _, cc := range p.conns[addr] {
|
||||
if cc.CanTakeNewRequest() {
|
||||
p.mu.Unlock()
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
if !dialOnMiss {
|
||||
p.mu.Unlock()
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
call := p.getStartDialLocked(addr)
|
||||
p.mu.Unlock()
|
||||
<-call.done
|
||||
return call.res, call.err
|
||||
}
|
||||
|
||||
// dialCall is an in-flight Transport dial call to a host.
|
||||
type dialCall struct {
|
||||
p *clientConnPool
|
||||
done chan struct{} // closed when done
|
||||
res *ClientConn // valid after done is closed
|
||||
err error // valid after done is closed
|
||||
}
|
||||
|
||||
// requires p.mu is held.
|
||||
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
|
||||
if call, ok := p.dialing[addr]; ok {
|
||||
// A dial is already in-flight. Don't start another.
|
||||
return call
|
||||
}
|
||||
call := &dialCall{p: p, done: make(chan struct{})}
|
||||
if p.dialing == nil {
|
||||
p.dialing = make(map[string]*dialCall)
|
||||
}
|
||||
p.dialing[addr] = call
|
||||
go call.dial(addr)
|
||||
return call
|
||||
}
|
||||
|
||||
// run in its own goroutine.
|
||||
func (c *dialCall) dial(addr string) {
|
||||
const singleUse = false // shared conn
|
||||
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
|
||||
close(c.done)
|
||||
|
||||
c.p.mu.Lock()
|
||||
delete(c.p.dialing, addr)
|
||||
if c.err == nil {
|
||||
c.p.addConnLocked(addr, c.res)
|
||||
}
|
||||
c.p.mu.Unlock()
|
||||
}
|
||||
|
||||
// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
|
||||
// already exist. It coalesces concurrent calls with the same key.
|
||||
// This is used by the http1 Transport code when it creates a new connection. Because
|
||||
// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
|
||||
// the protocol), it can get into a situation where it has multiple TLS connections.
|
||||
// This code decides which ones live or die.
|
||||
// The return value used is whether c was used.
|
||||
// c is never closed.
|
||||
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
|
||||
p.mu.Lock()
|
||||
for _, cc := range p.conns[key] {
|
||||
if cc.CanTakeNewRequest() {
|
||||
p.mu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
call, dup := p.addConnCalls[key]
|
||||
if !dup {
|
||||
if p.addConnCalls == nil {
|
||||
p.addConnCalls = make(map[string]*addConnCall)
|
||||
}
|
||||
call = &addConnCall{
|
||||
p: p,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
p.addConnCalls[key] = call
|
||||
go call.run(t, key, c)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
<-call.done
|
||||
if call.err != nil {
|
||||
return false, call.err
|
||||
}
|
||||
return !dup, nil
|
||||
}
|
||||
|
||||
type addConnCall struct {
|
||||
p *clientConnPool
|
||||
done chan struct{} // closed when done
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
|
||||
cc, err := t.NewClientConn(tc)
|
||||
|
||||
p := c.p
|
||||
p.mu.Lock()
|
||||
if err != nil {
|
||||
c.err = err
|
||||
} else {
|
||||
p.addConnLocked(key, cc)
|
||||
}
|
||||
delete(p.addConnCalls, key)
|
||||
p.mu.Unlock()
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) addConn(key string, cc *ClientConn) {
|
||||
p.mu.Lock()
|
||||
p.addConnLocked(key, cc)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// p.mu must be held
|
||||
func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
|
||||
for _, v := range p.conns[key] {
|
||||
if v == cc {
|
||||
return
|
||||
}
|
||||
}
|
||||
if p.conns == nil {
|
||||
p.conns = make(map[string][]*ClientConn)
|
||||
}
|
||||
if p.keys == nil {
|
||||
p.keys = make(map[*ClientConn][]string)
|
||||
}
|
||||
p.conns[key] = append(p.conns[key], cc)
|
||||
p.keys[cc] = append(p.keys[cc], key)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) MarkDead(cc *ClientConn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, key := range p.keys[cc] {
|
||||
vv, ok := p.conns[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
newList := filterOutClientConn(vv, cc)
|
||||
if len(newList) > 0 {
|
||||
p.conns[key] = newList
|
||||
} else {
|
||||
delete(p.conns, key)
|
||||
}
|
||||
}
|
||||
delete(p.keys, cc)
|
||||
}
|
||||
|
||||
func (p *clientConnPool) closeIdleConnections() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
// TODO: don't close a cc if it was just added to the pool
|
||||
// milliseconds ago and has never been used. There's currently
|
||||
// a small race window with the HTTP/1 Transport's integration
|
||||
// where it can add an idle conn just before using it, and
|
||||
// somebody else can concurrently call CloseIdleConns and
|
||||
// break some caller's RoundTrip.
|
||||
for _, vv := range p.conns {
|
||||
for _, cc := range vv {
|
||||
cc.closeIfIdle()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
|
||||
out := in[:0]
|
||||
for _, v := range in {
|
||||
if v != exclude {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
// If we filtered it out, zero out the last item to prevent
|
||||
// the GC from seeing it.
|
||||
if len(in) != len(out) {
|
||||
in[len(in)-1] = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// noDialClientConnPool is an implementation of http2.ClientConnPool
|
||||
// which never dials. We let the HTTP/1.1 client dial and use its TLS
|
||||
// connection instead.
|
||||
type noDialClientConnPool struct{ *clientConnPool }
|
||||
|
||||
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
|
||||
return p.getClientConn(req, addr, noDialOnMiss)
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.6
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func configureTransport(t1 *http.Transport) (*Transport, error) {
|
||||
connPool := new(clientConnPool)
|
||||
t2 := &Transport{
|
||||
ConnPool: noDialClientConnPool{connPool},
|
||||
t1: t1,
|
||||
}
|
||||
connPool.t = t2
|
||||
if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t1.TLSClientConfig == nil {
|
||||
t1.TLSClientConfig = new(tls.Config)
|
||||
}
|
||||
if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
|
||||
t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
|
||||
}
|
||||
if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
|
||||
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
|
||||
}
|
||||
upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
|
||||
addr := authorityAddr("https", authority)
|
||||
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
|
||||
go c.Close()
|
||||
return erringRoundTripper{err}
|
||||
} else if !used {
|
||||
// Turns out we don't need this c.
|
||||
// For example, two goroutines made requests to the same host
|
||||
// at the same time, both kicking off TCP dials. (since protocol
|
||||
// was unknown)
|
||||
go c.Close()
|
||||
}
|
||||
return t2
|
||||
}
|
||||
if m := t1.TLSNextProto; len(m) == 0 {
|
||||
t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
|
||||
"h2": upgradeFn,
|
||||
}
|
||||
} else {
|
||||
m["h2"] = upgradeFn
|
||||
}
|
||||
return t2, nil
|
||||
}
|
||||
|
||||
// registerHTTPSProtocol calls Transport.RegisterProtocol but
|
||||
// convering panics into errors.
|
||||
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("%v", e)
|
||||
}
|
||||
}()
|
||||
t.RegisterProtocol("https", rt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
|
||||
// if there's already has a cached connection to the host.
|
||||
type noDialH2RoundTripper struct{ t *Transport }
|
||||
|
||||
func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.t.RoundTrip(req)
|
||||
if err == ErrNoCachedConn {
|
||||
return nil, http.ErrSkipAltProtocol
|
||||
}
|
||||
return res, err
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
|
||||
type ErrCode uint32
|
||||
|
||||
const (
|
||||
ErrCodeNo ErrCode = 0x0
|
||||
ErrCodeProtocol ErrCode = 0x1
|
||||
ErrCodeInternal ErrCode = 0x2
|
||||
ErrCodeFlowControl ErrCode = 0x3
|
||||
ErrCodeSettingsTimeout ErrCode = 0x4
|
||||
ErrCodeStreamClosed ErrCode = 0x5
|
||||
ErrCodeFrameSize ErrCode = 0x6
|
||||
ErrCodeRefusedStream ErrCode = 0x7
|
||||
ErrCodeCancel ErrCode = 0x8
|
||||
ErrCodeCompression ErrCode = 0x9
|
||||
ErrCodeConnect ErrCode = 0xa
|
||||
ErrCodeEnhanceYourCalm ErrCode = 0xb
|
||||
ErrCodeInadequateSecurity ErrCode = 0xc
|
||||
ErrCodeHTTP11Required ErrCode = 0xd
|
||||
)
|
||||
|
||||
var errCodeName = map[ErrCode]string{
|
||||
ErrCodeNo: "NO_ERROR",
|
||||
ErrCodeProtocol: "PROTOCOL_ERROR",
|
||||
ErrCodeInternal: "INTERNAL_ERROR",
|
||||
ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
|
||||
ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
|
||||
ErrCodeStreamClosed: "STREAM_CLOSED",
|
||||
ErrCodeFrameSize: "FRAME_SIZE_ERROR",
|
||||
ErrCodeRefusedStream: "REFUSED_STREAM",
|
||||
ErrCodeCancel: "CANCEL",
|
||||
ErrCodeCompression: "COMPRESSION_ERROR",
|
||||
ErrCodeConnect: "CONNECT_ERROR",
|
||||
ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
|
||||
ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
|
||||
ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
|
||||
}
|
||||
|
||||
func (e ErrCode) String() string {
|
||||
if s, ok := errCodeName[e]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
|
||||
}
|
||||
|
||||
// ConnectionError is an error that results in the termination of the
|
||||
// entire connection.
|
||||
type ConnectionError ErrCode
|
||||
|
||||
func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: %s", ErrCode(e)) }
|
||||
|
||||
// StreamError is an error that only affects one stream within an
|
||||
// HTTP/2 connection.
|
||||
type StreamError struct {
|
||||
StreamID uint32
|
||||
Code ErrCode
|
||||
Cause error // optional additional detail
|
||||
}
|
||||
|
||||
func streamError(id uint32, code ErrCode) StreamError {
|
||||
return StreamError{StreamID: id, Code: code}
|
||||
}
|
||||
|
||||
func (e StreamError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
|
||||
}
|
||||
|
||||
// 6.9.1 The Flow Control Window
|
||||
// "If a sender receives a WINDOW_UPDATE that causes a flow control
|
||||
// window to exceed this maximum it MUST terminate either the stream
|
||||
// or the connection, as appropriate. For streams, [...]; for the
|
||||
// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
|
||||
type goAwayFlowError struct{}
|
||||
|
||||
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
|
||||
|
||||
// connErrorReason wraps a ConnectionError with an informative error about why it occurs.
|
||||
|
||||
// Errors of this type are only returned by the frame parser functions
|
||||
// and converted into ConnectionError(ErrCodeProtocol).
|
||||
type connError struct {
|
||||
Code ErrCode
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e connError) Error() string {
|
||||
return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
|
||||
}
|
||||
|
||||
type pseudoHeaderError string
|
||||
|
||||
func (e pseudoHeaderError) Error() string {
|
||||
return fmt.Sprintf("invalid pseudo-header %q", string(e))
|
||||
}
|
||||
|
||||
type duplicatePseudoHeaderError string
|
||||
|
||||
func (e duplicatePseudoHeaderError) Error() string {
|
||||
return fmt.Sprintf("duplicate pseudo-header %q", string(e))
|
||||
}
|
||||
|
||||
type headerFieldNameError string
|
||||
|
||||
func (e headerFieldNameError) Error() string {
|
||||
return fmt.Sprintf("invalid header field name %q", string(e))
|
||||
}
|
||||
|
||||
type headerFieldValueError string
|
||||
|
||||
func (e headerFieldValueError) Error() string {
|
||||
return fmt.Sprintf("invalid header field value %q", string(e))
|
||||
}
|
||||
|
||||
var (
|
||||
errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
|
||||
errPseudoAfterRegular = errors.New("pseudo header field after regular")
|
||||
)
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// fixedBuffer is an io.ReadWriter backed by a fixed size buffer.
|
||||
// It never allocates, but moves old data as new data is written.
|
||||
type fixedBuffer struct {
|
||||
buf []byte
|
||||
r, w int
|
||||
}
|
||||
|
||||
var (
|
||||
errReadEmpty = errors.New("read from empty fixedBuffer")
|
||||
errWriteFull = errors.New("write on full fixedBuffer")
|
||||
)
|
||||
|
||||
// Read copies bytes from the buffer into p.
|
||||
// It is an error to read when no data is available.
|
||||
func (b *fixedBuffer) Read(p []byte) (n int, err error) {
|
||||
if b.r == b.w {
|
||||
return 0, errReadEmpty
|
||||
}
|
||||
n = copy(p, b.buf[b.r:b.w])
|
||||
b.r += n
|
||||
if b.r == b.w {
|
||||
b.r = 0
|
||||
b.w = 0
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Len returns the number of bytes of the unread portion of the buffer.
|
||||
func (b *fixedBuffer) Len() int {
|
||||
return b.w - b.r
|
||||
}
|
||||
|
||||
// Write copies bytes from p into the buffer.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (b *fixedBuffer) Write(p []byte) (n int, err error) {
|
||||
// Slide existing data to beginning.
|
||||
if b.r > 0 && len(p) > len(b.buf)-b.w {
|
||||
copy(b.buf, b.buf[b.r:b.w])
|
||||
b.w -= b.r
|
||||
b.r = 0
|
||||
}
|
||||
|
||||
// Write new data.
|
||||
n = copy(b.buf[b.w:], p)
|
||||
b.w += n
|
||||
if n < len(p) {
|
||||
err = errWriteFull
|
||||
}
|
||||
return n, err
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Flow control
|
||||
|
||||
package http2
|
||||
|
||||
// flow is the flow control window's size.
|
||||
type flow struct {
|
||||
// n is the number of DATA bytes we're allowed to send.
|
||||
// A flow is kept both on a conn and a per-stream.
|
||||
n int32
|
||||
|
||||
// conn points to the shared connection-level flow that is
|
||||
// shared by all streams on that conn. It is nil for the flow
|
||||
// that's on the conn directly.
|
||||
conn *flow
|
||||
}
|
||||
|
||||
func (f *flow) setConnFlow(cf *flow) { f.conn = cf }
|
||||
|
||||
func (f *flow) available() int32 {
|
||||
n := f.n
|
||||
if f.conn != nil && f.conn.n < n {
|
||||
n = f.conn.n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (f *flow) take(n int32) {
|
||||
if n > f.available() {
|
||||
panic("internal error: took too much")
|
||||
}
|
||||
f.n -= n
|
||||
if f.conn != nil {
|
||||
f.conn.n -= n
|
||||
}
|
||||
}
|
||||
|
||||
// add adds n bytes (positive or negative) to the flow control window.
|
||||
// It returns false if the sum would exceed 2^31-1.
|
||||
func (f *flow) add(n int32) bool {
|
||||
remain := (1<<31 - 1) - f.n
|
||||
if n > remain {
|
||||
return false
|
||||
}
|
||||
f.n += n
|
||||
return true
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,43 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.6
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func transportExpectContinueTimeout(t1 *http.Transport) time.Duration {
|
||||
return t1.ExpectContinueTimeout
|
||||
}
|
||||
|
||||
// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec.
|
||||
func isBadCipher(cipher uint16) bool {
|
||||
switch cipher {
|
||||
case tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
|
||||
// Reject cipher suites from Appendix A.
|
||||
// "This list includes those cipher suites that do not
|
||||
// offer an ephemeral key exchange and those that are
|
||||
// based on the TLS null, stream or block cipher type"
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.7
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"time"
|
||||
)
|
||||
|
||||
type contextContext interface {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
|
||||
if hs := opts.baseConfig(); hs != nil {
|
||||
ctx = context.WithValue(ctx, http.ServerContextKey, hs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
|
||||
func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
||||
type clientTrace httptrace.ClientTrace
|
||||
|
||||
func reqContext(r *http.Request) context.Context { return r.Context() }
|
||||
|
||||
func setResponseUncompressed(res *http.Response) { res.Uncompressed = true }
|
||||
|
||||
func traceGotConn(req *http.Request, cc *ClientConn) {
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
if trace == nil || trace.GotConn == nil {
|
||||
return
|
||||
}
|
||||
ci := httptrace.GotConnInfo{Conn: cc.tconn}
|
||||
cc.mu.Lock()
|
||||
ci.Reused = cc.nextStreamID > 1
|
||||
ci.WasIdle = len(cc.streams) == 0 && ci.Reused
|
||||
if ci.WasIdle && !cc.lastActive.IsZero() {
|
||||
ci.IdleTime = time.Now().Sub(cc.lastActive)
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
|
||||
trace.GotConn(ci)
|
||||
}
|
||||
|
||||
func traceWroteHeaders(trace *clientTrace) {
|
||||
if trace != nil && trace.WroteHeaders != nil {
|
||||
trace.WroteHeaders()
|
||||
}
|
||||
}
|
||||
|
||||
func traceGot100Continue(trace *clientTrace) {
|
||||
if trace != nil && trace.Got100Continue != nil {
|
||||
trace.Got100Continue()
|
||||
}
|
||||
}
|
||||
|
||||
func traceWait100Continue(trace *clientTrace) {
|
||||
if trace != nil && trace.Wait100Continue != nil {
|
||||
trace.Wait100Continue()
|
||||
}
|
||||
}
|
||||
|
||||
func traceWroteRequest(trace *clientTrace, err error) {
|
||||
if trace != nil && trace.WroteRequest != nil {
|
||||
trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
|
||||
}
|
||||
}
|
||||
|
||||
func traceFirstResponseByte(trace *clientTrace) {
|
||||
if trace != nil && trace.GotFirstResponseByte != nil {
|
||||
trace.GotFirstResponseByte()
|
||||
}
|
||||
}
|
||||
|
||||
func requestTrace(req *http.Request) *clientTrace {
|
||||
trace := httptrace.ContextClientTrace(req.Context())
|
||||
return (*clientTrace)(trace)
|
||||
}
|
|
@ -0,0 +1,170 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Defensive debug-only utility to track that functions run on the
|
||||
// goroutine that they're supposed to.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
|
||||
|
||||
type goroutineLock uint64
|
||||
|
||||
func newGoroutineLock() goroutineLock {
|
||||
if !DebugGoroutines {
|
||||
return 0
|
||||
}
|
||||
return goroutineLock(curGoroutineID())
|
||||
}
|
||||
|
||||
func (g goroutineLock) check() {
|
||||
if !DebugGoroutines {
|
||||
return
|
||||
}
|
||||
if curGoroutineID() != uint64(g) {
|
||||
panic("running on the wrong goroutine")
|
||||
}
|
||||
}
|
||||
|
||||
func (g goroutineLock) checkNotOn() {
|
||||
if !DebugGoroutines {
|
||||
return
|
||||
}
|
||||
if curGoroutineID() == uint64(g) {
|
||||
panic("running on the wrong goroutine")
|
||||
}
|
||||
}
|
||||
|
||||
var goroutineSpace = []byte("goroutine ")
|
||||
|
||||
func curGoroutineID() uint64 {
|
||||
bp := littleBuf.Get().(*[]byte)
|
||||
defer littleBuf.Put(bp)
|
||||
b := *bp
|
||||
b = b[:runtime.Stack(b, false)]
|
||||
// Parse the 4707 out of "goroutine 4707 ["
|
||||
b = bytes.TrimPrefix(b, goroutineSpace)
|
||||
i := bytes.IndexByte(b, ' ')
|
||||
if i < 0 {
|
||||
panic(fmt.Sprintf("No space found in %q", b))
|
||||
}
|
||||
b = b[:i]
|
||||
n, err := parseUintBytes(b, 10, 64)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
var littleBuf = sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 64)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// parseUintBytes is like strconv.ParseUint, but using a []byte.
|
||||
func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
|
||||
var cutoff, maxVal uint64
|
||||
|
||||
if bitSize == 0 {
|
||||
bitSize = int(strconv.IntSize)
|
||||
}
|
||||
|
||||
s0 := s
|
||||
switch {
|
||||
case len(s) < 1:
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
|
||||
case 2 <= base && base <= 36:
|
||||
// valid base; nothing to do
|
||||
|
||||
case base == 0:
|
||||
// Look for octal, hex prefix.
|
||||
switch {
|
||||
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
|
||||
base = 16
|
||||
s = s[2:]
|
||||
if len(s) < 1 {
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
case s[0] == '0':
|
||||
base = 8
|
||||
default:
|
||||
base = 10
|
||||
}
|
||||
|
||||
default:
|
||||
err = errors.New("invalid base " + strconv.Itoa(base))
|
||||
goto Error
|
||||
}
|
||||
|
||||
n = 0
|
||||
cutoff = cutoff64(base)
|
||||
maxVal = 1<<uint(bitSize) - 1
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
var v byte
|
||||
d := s[i]
|
||||
switch {
|
||||
case '0' <= d && d <= '9':
|
||||
v = d - '0'
|
||||
case 'a' <= d && d <= 'z':
|
||||
v = d - 'a' + 10
|
||||
case 'A' <= d && d <= 'Z':
|
||||
v = d - 'A' + 10
|
||||
default:
|
||||
n = 0
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
if int(v) >= base {
|
||||
n = 0
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
|
||||
if n >= cutoff {
|
||||
// n*base overflows
|
||||
n = 1<<64 - 1
|
||||
err = strconv.ErrRange
|
||||
goto Error
|
||||
}
|
||||
n *= uint64(base)
|
||||
|
||||
n1 := n + uint64(v)
|
||||
if n1 < n || n1 > maxVal {
|
||||
// n+v overflows
|
||||
n = 1<<64 - 1
|
||||
err = strconv.ErrRange
|
||||
goto Error
|
||||
}
|
||||
n = n1
|
||||
}
|
||||
|
||||
return n, nil
|
||||
|
||||
Error:
|
||||
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
|
||||
}
|
||||
|
||||
// Return the first number n such that n*base >= 1<<64.
|
||||
func cutoff64(base int) uint64 {
|
||||
if base < 2 {
|
||||
return 0
|
||||
}
|
||||
return (1<<64-1)/uint64(base) + 1
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case
|
||||
commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case
|
||||
)
|
||||
|
||||
func init() {
|
||||
for _, v := range []string{
|
||||
"accept",
|
||||
"accept-charset",
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"accept-ranges",
|
||||
"age",
|
||||
"access-control-allow-origin",
|
||||
"allow",
|
||||
"authorization",
|
||||
"cache-control",
|
||||
"content-disposition",
|
||||
"content-encoding",
|
||||
"content-language",
|
||||
"content-length",
|
||||
"content-location",
|
||||
"content-range",
|
||||
"content-type",
|
||||
"cookie",
|
||||
"date",
|
||||
"etag",
|
||||
"expect",
|
||||
"expires",
|
||||
"from",
|
||||
"host",
|
||||
"if-match",
|
||||
"if-modified-since",
|
||||
"if-none-match",
|
||||
"if-unmodified-since",
|
||||
"last-modified",
|
||||
"link",
|
||||
"location",
|
||||
"max-forwards",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"range",
|
||||
"referer",
|
||||
"refresh",
|
||||
"retry-after",
|
||||
"server",
|
||||
"set-cookie",
|
||||
"strict-transport-security",
|
||||
"trailer",
|
||||
"transfer-encoding",
|
||||
"user-agent",
|
||||
"vary",
|
||||
"via",
|
||||
"www-authenticate",
|
||||
} {
|
||||
chk := http.CanonicalHeaderKey(v)
|
||||
commonLowerHeader[chk] = v
|
||||
commonCanonHeader[v] = chk
|
||||
}
|
||||
}
|
||||
|
||||
func lowerHeader(v string) string {
|
||||
if s, ok := commonLowerHeader[v]; ok {
|
||||
return s
|
||||
}
|
||||
return strings.ToLower(v)
|
||||
}
|
|
@ -0,0 +1,251 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
uint32Max = ^uint32(0)
|
||||
initialHeaderTableSize = 4096
|
||||
)
|
||||
|
||||
type Encoder struct {
|
||||
dynTab dynamicTable
|
||||
// minSize is the minimum table size set by
|
||||
// SetMaxDynamicTableSize after the previous Header Table Size
|
||||
// Update.
|
||||
minSize uint32
|
||||
// maxSizeLimit is the maximum table size this encoder
|
||||
// supports. This will protect the encoder from too large
|
||||
// size.
|
||||
maxSizeLimit uint32
|
||||
// tableSizeUpdate indicates whether "Header Table Size
|
||||
// Update" is required.
|
||||
tableSizeUpdate bool
|
||||
w io.Writer
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewEncoder returns a new Encoder which performs HPACK encoding. An
|
||||
// encoded data is written to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
e := &Encoder{
|
||||
minSize: uint32Max,
|
||||
maxSizeLimit: initialHeaderTableSize,
|
||||
tableSizeUpdate: false,
|
||||
w: w,
|
||||
}
|
||||
e.dynTab.setMaxSize(initialHeaderTableSize)
|
||||
return e
|
||||
}
|
||||
|
||||
// WriteField encodes f into a single Write to e's underlying Writer.
|
||||
// This function may also produce bytes for "Header Table Size Update"
|
||||
// if necessary. If produced, it is done before encoding f.
|
||||
func (e *Encoder) WriteField(f HeaderField) error {
|
||||
e.buf = e.buf[:0]
|
||||
|
||||
if e.tableSizeUpdate {
|
||||
e.tableSizeUpdate = false
|
||||
if e.minSize < e.dynTab.maxSize {
|
||||
e.buf = appendTableSize(e.buf, e.minSize)
|
||||
}
|
||||
e.minSize = uint32Max
|
||||
e.buf = appendTableSize(e.buf, e.dynTab.maxSize)
|
||||
}
|
||||
|
||||
idx, nameValueMatch := e.searchTable(f)
|
||||
if nameValueMatch {
|
||||
e.buf = appendIndexed(e.buf, idx)
|
||||
} else {
|
||||
indexing := e.shouldIndex(f)
|
||||
if indexing {
|
||||
e.dynTab.add(f)
|
||||
}
|
||||
|
||||
if idx == 0 {
|
||||
e.buf = appendNewName(e.buf, f, indexing)
|
||||
} else {
|
||||
e.buf = appendIndexedName(e.buf, f, idx, indexing)
|
||||
}
|
||||
}
|
||||
n, err := e.w.Write(e.buf)
|
||||
if err == nil && n != len(e.buf) {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// searchTable searches f in both stable and dynamic header tables.
|
||||
// The static header table is searched first. Only when there is no
|
||||
// exact match for both name and value, the dynamic header table is
|
||||
// then searched. If there is no match, i is 0. If both name and value
|
||||
// match, i is the matched index and nameValueMatch becomes true. If
|
||||
// only name matches, i points to that index and nameValueMatch
|
||||
// becomes false.
|
||||
func (e *Encoder) searchTable(f HeaderField) (i uint64, nameValueMatch bool) {
|
||||
for idx, hf := range staticTable {
|
||||
if !constantTimeStringCompare(hf.Name, f.Name) {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
i = uint64(idx + 1)
|
||||
}
|
||||
if f.Sensitive {
|
||||
continue
|
||||
}
|
||||
if !constantTimeStringCompare(hf.Value, f.Value) {
|
||||
continue
|
||||
}
|
||||
i = uint64(idx + 1)
|
||||
nameValueMatch = true
|
||||
return
|
||||
}
|
||||
|
||||
j, nameValueMatch := e.dynTab.search(f)
|
||||
if nameValueMatch || (i == 0 && j != 0) {
|
||||
i = j + uint64(len(staticTable))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetMaxDynamicTableSize changes the dynamic header table size to v.
|
||||
// The actual size is bounded by the value passed to
|
||||
// SetMaxDynamicTableSizeLimit.
|
||||
func (e *Encoder) SetMaxDynamicTableSize(v uint32) {
|
||||
if v > e.maxSizeLimit {
|
||||
v = e.maxSizeLimit
|
||||
}
|
||||
if v < e.minSize {
|
||||
e.minSize = v
|
||||
}
|
||||
e.tableSizeUpdate = true
|
||||
e.dynTab.setMaxSize(v)
|
||||
}
|
||||
|
||||
// SetMaxDynamicTableSizeLimit changes the maximum value that can be
|
||||
// specified in SetMaxDynamicTableSize to v. By default, it is set to
|
||||
// 4096, which is the same size of the default dynamic header table
|
||||
// size described in HPACK specification. If the current maximum
|
||||
// dynamic header table size is strictly greater than v, "Header Table
|
||||
// Size Update" will be done in the next WriteField call and the
|
||||
// maximum dynamic header table size is truncated to v.
|
||||
func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
|
||||
e.maxSizeLimit = v
|
||||
if e.dynTab.maxSize > v {
|
||||
e.tableSizeUpdate = true
|
||||
e.dynTab.setMaxSize(v)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldIndex reports whether f should be indexed.
|
||||
func (e *Encoder) shouldIndex(f HeaderField) bool {
|
||||
return !f.Sensitive && f.Size() <= e.dynTab.maxSize
|
||||
}
|
||||
|
||||
// appendIndexed appends index i, as encoded in "Indexed Header Field"
|
||||
// representation, to dst and returns the extended buffer.
|
||||
func appendIndexed(dst []byte, i uint64) []byte {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 7, i)
|
||||
dst[first] |= 0x80
|
||||
return dst
|
||||
}
|
||||
|
||||
// appendNewName appends f, as encoded in one of "Literal Header field
|
||||
// - New Name" representation variants, to dst and returns the
|
||||
// extended buffer.
|
||||
//
|
||||
// If f.Sensitive is true, "Never Indexed" representation is used. If
|
||||
// f.Sensitive is false and indexing is true, "Inremental Indexing"
|
||||
// representation is used.
|
||||
func appendNewName(dst []byte, f HeaderField, indexing bool) []byte {
|
||||
dst = append(dst, encodeTypeByte(indexing, f.Sensitive))
|
||||
dst = appendHpackString(dst, f.Name)
|
||||
return appendHpackString(dst, f.Value)
|
||||
}
|
||||
|
||||
// appendIndexedName appends f and index i referring indexed name
|
||||
// entry, as encoded in one of "Literal Header field - Indexed Name"
|
||||
// representation variants, to dst and returns the extended buffer.
|
||||
//
|
||||
// If f.Sensitive is true, "Never Indexed" representation is used. If
|
||||
// f.Sensitive is false and indexing is true, "Incremental Indexing"
|
||||
// representation is used.
|
||||
func appendIndexedName(dst []byte, f HeaderField, i uint64, indexing bool) []byte {
|
||||
first := len(dst)
|
||||
var n byte
|
||||
if indexing {
|
||||
n = 6
|
||||
} else {
|
||||
n = 4
|
||||
}
|
||||
dst = appendVarInt(dst, n, i)
|
||||
dst[first] |= encodeTypeByte(indexing, f.Sensitive)
|
||||
return appendHpackString(dst, f.Value)
|
||||
}
|
||||
|
||||
// appendTableSize appends v, as encoded in "Header Table Size Update"
|
||||
// representation, to dst and returns the extended buffer.
|
||||
func appendTableSize(dst []byte, v uint32) []byte {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 5, uint64(v))
|
||||
dst[first] |= 0x20
|
||||
return dst
|
||||
}
|
||||
|
||||
// appendVarInt appends i, as encoded in variable integer form using n
|
||||
// bit prefix, to dst and returns the extended buffer.
|
||||
//
|
||||
// See
|
||||
// http://http2.github.io/http2-spec/compression.html#integer.representation
|
||||
func appendVarInt(dst []byte, n byte, i uint64) []byte {
|
||||
k := uint64((1 << n) - 1)
|
||||
if i < k {
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
dst = append(dst, byte(k))
|
||||
i -= k
|
||||
for ; i >= 128; i >>= 7 {
|
||||
dst = append(dst, byte(0x80|(i&0x7f)))
|
||||
}
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
|
||||
// appendHpackString appends s, as encoded in "String Literal"
|
||||
// representation, to dst and returns the the extended buffer.
|
||||
//
|
||||
// s will be encoded in Huffman codes only when it produces strictly
|
||||
// shorter byte string.
|
||||
func appendHpackString(dst []byte, s string) []byte {
|
||||
huffmanLength := HuffmanEncodeLength(s)
|
||||
if huffmanLength < uint64(len(s)) {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 7, huffmanLength)
|
||||
dst = AppendHuffmanString(dst, s)
|
||||
dst[first] |= 0x80
|
||||
} else {
|
||||
dst = appendVarInt(dst, 7, uint64(len(s)))
|
||||
dst = append(dst, s...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// encodeTypeByte returns type byte. If sensitive is true, type byte
|
||||
// for "Never Indexed" representation is returned. If sensitive is
|
||||
// false and indexing is true, type byte for "Incremental Indexing"
|
||||
// representation is returned. Otherwise, type byte for "Without
|
||||
// Indexing" is returned.
|
||||
func encodeTypeByte(indexing, sensitive bool) byte {
|
||||
if sensitive {
|
||||
return 0x10
|
||||
}
|
||||
if indexing {
|
||||
return 0x40
|
||||
}
|
||||
return 0
|
||||
}
|
|
@ -0,0 +1,542 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package hpack implements HPACK, a compression format for
|
||||
// efficiently representing HTTP header fields in the context of HTTP/2.
|
||||
//
|
||||
// See http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-09
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// A DecodingError is something the spec defines as a decoding error.
|
||||
type DecodingError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (de DecodingError) Error() string {
|
||||
return fmt.Sprintf("decoding error: %v", de.Err)
|
||||
}
|
||||
|
||||
// An InvalidIndexError is returned when an encoder references a table
|
||||
// entry before the static table or after the end of the dynamic table.
|
||||
type InvalidIndexError int
|
||||
|
||||
func (e InvalidIndexError) Error() string {
|
||||
return fmt.Sprintf("invalid indexed representation index %d", int(e))
|
||||
}
|
||||
|
||||
// A HeaderField is a name-value pair. Both the name and value are
|
||||
// treated as opaque sequences of octets.
|
||||
type HeaderField struct {
|
||||
Name, Value string
|
||||
|
||||
// Sensitive means that this header field should never be
|
||||
// indexed.
|
||||
Sensitive bool
|
||||
}
|
||||
|
||||
// IsPseudo reports whether the header field is an http2 pseudo header.
|
||||
// That is, it reports whether it starts with a colon.
|
||||
// It is not otherwise guaranteed to be a valid pseudo header field,
|
||||
// though.
|
||||
func (hf HeaderField) IsPseudo() bool {
|
||||
return len(hf.Name) != 0 && hf.Name[0] == ':'
|
||||
}
|
||||
|
||||
func (hf HeaderField) String() string {
|
||||
var suffix string
|
||||
if hf.Sensitive {
|
||||
suffix = " (sensitive)"
|
||||
}
|
||||
return fmt.Sprintf("header field %q = %q%s", hf.Name, hf.Value, suffix)
|
||||
}
|
||||
|
||||
// Size returns the size of an entry per RFC 7540 section 5.2.
|
||||
func (hf HeaderField) Size() uint32 {
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
|
||||
// "The size of the dynamic table is the sum of the size of
|
||||
// its entries. The size of an entry is the sum of its name's
|
||||
// length in octets (as defined in Section 5.2), its value's
|
||||
// length in octets (see Section 5.2), plus 32. The size of
|
||||
// an entry is calculated using the length of the name and
|
||||
// value without any Huffman encoding applied."
|
||||
|
||||
// This can overflow if somebody makes a large HeaderField
|
||||
// Name and/or Value by hand, but we don't care, because that
|
||||
// won't happen on the wire because the encoding doesn't allow
|
||||
// it.
|
||||
return uint32(len(hf.Name) + len(hf.Value) + 32)
|
||||
}
|
||||
|
||||
// A Decoder is the decoding context for incremental processing of
|
||||
// header blocks.
|
||||
type Decoder struct {
|
||||
dynTab dynamicTable
|
||||
emit func(f HeaderField)
|
||||
|
||||
emitEnabled bool // whether calls to emit are enabled
|
||||
maxStrLen int // 0 means unlimited
|
||||
|
||||
// buf is the unparsed buffer. It's only written to
|
||||
// saveBuf if it was truncated in the middle of a header
|
||||
// block. Because it's usually not owned, we can only
|
||||
// process it under Write.
|
||||
buf []byte // not owned; only valid during Write
|
||||
|
||||
// saveBuf is previous data passed to Write which we weren't able
|
||||
// to fully parse before. Unlike buf, we own this data.
|
||||
saveBuf bytes.Buffer
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder with the provided maximum dynamic
|
||||
// table size. The emitFunc will be called for each valid field
|
||||
// parsed, in the same goroutine as calls to Write, before Write returns.
|
||||
func NewDecoder(maxDynamicTableSize uint32, emitFunc func(f HeaderField)) *Decoder {
|
||||
d := &Decoder{
|
||||
emit: emitFunc,
|
||||
emitEnabled: true,
|
||||
}
|
||||
d.dynTab.allowedMaxSize = maxDynamicTableSize
|
||||
d.dynTab.setMaxSize(maxDynamicTableSize)
|
||||
return d
|
||||
}
|
||||
|
||||
// ErrStringLength is returned by Decoder.Write when the max string length
|
||||
// (as configured by Decoder.SetMaxStringLength) would be violated.
|
||||
var ErrStringLength = errors.New("hpack: string too long")
|
||||
|
||||
// SetMaxStringLength sets the maximum size of a HeaderField name or
|
||||
// value string. If a string exceeds this length (even after any
|
||||
// decompression), Write will return ErrStringLength.
|
||||
// A value of 0 means unlimited and is the default from NewDecoder.
|
||||
func (d *Decoder) SetMaxStringLength(n int) {
|
||||
d.maxStrLen = n
|
||||
}
|
||||
|
||||
// SetEmitFunc changes the callback used when new header fields
|
||||
// are decoded.
|
||||
// It must be non-nil. It does not affect EmitEnabled.
|
||||
func (d *Decoder) SetEmitFunc(emitFunc func(f HeaderField)) {
|
||||
d.emit = emitFunc
|
||||
}
|
||||
|
||||
// SetEmitEnabled controls whether the emitFunc provided to NewDecoder
|
||||
// should be called. The default is true.
|
||||
//
|
||||
// This facility exists to let servers enforce MAX_HEADER_LIST_SIZE
|
||||
// while still decoding and keeping in-sync with decoder state, but
|
||||
// without doing unnecessary decompression or generating unnecessary
|
||||
// garbage for header fields past the limit.
|
||||
func (d *Decoder) SetEmitEnabled(v bool) { d.emitEnabled = v }
|
||||
|
||||
// EmitEnabled reports whether calls to the emitFunc provided to NewDecoder
|
||||
// are currently enabled. The default is true.
|
||||
func (d *Decoder) EmitEnabled() bool { return d.emitEnabled }
|
||||
|
||||
// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
|
||||
// underlying buffers for garbage reasons.
|
||||
|
||||
func (d *Decoder) SetMaxDynamicTableSize(v uint32) {
|
||||
d.dynTab.setMaxSize(v)
|
||||
}
|
||||
|
||||
// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded
|
||||
// stream (via dynamic table size updates) may set the maximum size
|
||||
// to.
|
||||
func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
|
||||
d.dynTab.allowedMaxSize = v
|
||||
}
|
||||
|
||||
type dynamicTable struct {
|
||||
// ents is the FIFO described at
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
|
||||
// The newest (low index) is append at the end, and items are
|
||||
// evicted from the front.
|
||||
ents []HeaderField
|
||||
size uint32
|
||||
maxSize uint32 // current maxSize
|
||||
allowedMaxSize uint32 // maxSize may go up to this, inclusive
|
||||
}
|
||||
|
||||
func (dt *dynamicTable) setMaxSize(v uint32) {
|
||||
dt.maxSize = v
|
||||
dt.evict()
|
||||
}
|
||||
|
||||
// TODO: change dynamicTable to be a struct with a slice and a size int field,
|
||||
// per http://http2.github.io/http2-spec/compression.html#rfc.section.4.1:
|
||||
//
|
||||
//
|
||||
// Then make add increment the size. maybe the max size should move from Decoder to
|
||||
// dynamicTable and add should return an ok bool if there was enough space.
|
||||
//
|
||||
// Later we'll need a remove operation on dynamicTable.
|
||||
|
||||
func (dt *dynamicTable) add(f HeaderField) {
|
||||
dt.ents = append(dt.ents, f)
|
||||
dt.size += f.Size()
|
||||
dt.evict()
|
||||
}
|
||||
|
||||
// If we're too big, evict old stuff (front of the slice)
|
||||
func (dt *dynamicTable) evict() {
|
||||
base := dt.ents // keep base pointer of slice
|
||||
for dt.size > dt.maxSize {
|
||||
dt.size -= dt.ents[0].Size()
|
||||
dt.ents = dt.ents[1:]
|
||||
}
|
||||
|
||||
// Shift slice contents down if we evicted things.
|
||||
if len(dt.ents) != len(base) {
|
||||
copy(base, dt.ents)
|
||||
dt.ents = base[:len(dt.ents)]
|
||||
}
|
||||
}
|
||||
|
||||
// constantTimeStringCompare compares string a and b in a constant
|
||||
// time manner.
|
||||
func constantTimeStringCompare(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
c := byte(0)
|
||||
|
||||
for i := 0; i < len(a); i++ {
|
||||
c |= a[i] ^ b[i]
|
||||
}
|
||||
|
||||
return c == 0
|
||||
}
|
||||
|
||||
// Search searches f in the table. The return value i is 0 if there is
|
||||
// no name match. If there is name match or name/value match, i is the
|
||||
// index of that entry (1-based). If both name and value match,
|
||||
// nameValueMatch becomes true.
|
||||
func (dt *dynamicTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
|
||||
l := len(dt.ents)
|
||||
for j := l - 1; j >= 0; j-- {
|
||||
ent := dt.ents[j]
|
||||
if !constantTimeStringCompare(ent.Name, f.Name) {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
i = uint64(l - j)
|
||||
}
|
||||
if f.Sensitive {
|
||||
continue
|
||||
}
|
||||
if !constantTimeStringCompare(ent.Value, f.Value) {
|
||||
continue
|
||||
}
|
||||
i = uint64(l - j)
|
||||
nameValueMatch = true
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Decoder) maxTableIndex() int {
|
||||
return len(d.dynTab.ents) + len(staticTable)
|
||||
}
|
||||
|
||||
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
|
||||
if i < 1 {
|
||||
return
|
||||
}
|
||||
if i > uint64(d.maxTableIndex()) {
|
||||
return
|
||||
}
|
||||
if i <= uint64(len(staticTable)) {
|
||||
return staticTable[i-1], true
|
||||
}
|
||||
dents := d.dynTab.ents
|
||||
return dents[len(dents)-(int(i)-len(staticTable))], true
|
||||
}
|
||||
|
||||
// Decode decodes an entire block.
|
||||
//
|
||||
// TODO: remove this method and make it incremental later? This is
|
||||
// easier for debugging now.
|
||||
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
|
||||
var hf []HeaderField
|
||||
saveFunc := d.emit
|
||||
defer func() { d.emit = saveFunc }()
|
||||
d.emit = func(f HeaderField) { hf = append(hf, f) }
|
||||
if _, err := d.Write(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hf, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) Close() error {
|
||||
if d.saveBuf.Len() > 0 {
|
||||
d.saveBuf.Reset()
|
||||
return DecodingError{errors.New("truncated headers")}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Decoder) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
// Prevent state machine CPU attacks (making us redo
|
||||
// work up to the point of finding out we don't have
|
||||
// enough data)
|
||||
return
|
||||
}
|
||||
// Only copy the data if we have to. Optimistically assume
|
||||
// that p will contain a complete header block.
|
||||
if d.saveBuf.Len() == 0 {
|
||||
d.buf = p
|
||||
} else {
|
||||
d.saveBuf.Write(p)
|
||||
d.buf = d.saveBuf.Bytes()
|
||||
d.saveBuf.Reset()
|
||||
}
|
||||
|
||||
for len(d.buf) > 0 {
|
||||
err = d.parseHeaderFieldRepr()
|
||||
if err == errNeedMore {
|
||||
// Extra paranoia, making sure saveBuf won't
|
||||
// get too large. All the varint and string
|
||||
// reading code earlier should already catch
|
||||
// overlong things and return ErrStringLength,
|
||||
// but keep this as a last resort.
|
||||
const varIntOverhead = 8 // conservative
|
||||
if d.maxStrLen != 0 && int64(len(d.buf)) > 2*(int64(d.maxStrLen)+varIntOverhead) {
|
||||
return 0, ErrStringLength
|
||||
}
|
||||
d.saveBuf.Write(d.buf)
|
||||
return len(p), nil
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
// errNeedMore is an internal sentinel error value that means the
|
||||
// buffer is truncated and we need to read more data before we can
|
||||
// continue parsing.
|
||||
var errNeedMore = errors.New("need more data")
|
||||
|
||||
type indexType int
|
||||
|
||||
const (
|
||||
indexedTrue indexType = iota
|
||||
indexedFalse
|
||||
indexedNever
|
||||
)
|
||||
|
||||
func (v indexType) indexed() bool { return v == indexedTrue }
|
||||
func (v indexType) sensitive() bool { return v == indexedNever }
|
||||
|
||||
// returns errNeedMore if there isn't enough data available.
|
||||
// any other error is fatal.
|
||||
// consumes d.buf iff it returns nil.
|
||||
// precondition: must be called with len(d.buf) > 0
|
||||
func (d *Decoder) parseHeaderFieldRepr() error {
|
||||
b := d.buf[0]
|
||||
switch {
|
||||
case b&128 != 0:
|
||||
// Indexed representation.
|
||||
// High bit set?
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
|
||||
return d.parseFieldIndexed()
|
||||
case b&192 == 64:
|
||||
// 6.2.1 Literal Header Field with Incremental Indexing
|
||||
// 0b10xxxxxx: top two bits are 10
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1
|
||||
return d.parseFieldLiteral(6, indexedTrue)
|
||||
case b&240 == 0:
|
||||
// 6.2.2 Literal Header Field without Indexing
|
||||
// 0b0000xxxx: top four bits are 0000
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2
|
||||
return d.parseFieldLiteral(4, indexedFalse)
|
||||
case b&240 == 16:
|
||||
// 6.2.3 Literal Header Field never Indexed
|
||||
// 0b0001xxxx: top four bits are 0001
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3
|
||||
return d.parseFieldLiteral(4, indexedNever)
|
||||
case b&224 == 32:
|
||||
// 6.3 Dynamic Table Size Update
|
||||
// Top three bits are '001'.
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3
|
||||
return d.parseDynamicTableSizeUpdate()
|
||||
}
|
||||
|
||||
return DecodingError{errors.New("invalid encoding")}
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseFieldIndexed() error {
|
||||
buf := d.buf
|
||||
idx, buf, err := readVarInt(7, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf, ok := d.at(idx)
|
||||
if !ok {
|
||||
return DecodingError{InvalidIndexError(idx)}
|
||||
}
|
||||
d.buf = buf
|
||||
return d.callEmit(HeaderField{Name: hf.Name, Value: hf.Value})
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
|
||||
buf := d.buf
|
||||
nameIdx, buf, err := readVarInt(n, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var hf HeaderField
|
||||
wantStr := d.emitEnabled || it.indexed()
|
||||
if nameIdx > 0 {
|
||||
ihf, ok := d.at(nameIdx)
|
||||
if !ok {
|
||||
return DecodingError{InvalidIndexError(nameIdx)}
|
||||
}
|
||||
hf.Name = ihf.Name
|
||||
} else {
|
||||
hf.Name, buf, err = d.readString(buf, wantStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
hf.Value, buf, err = d.readString(buf, wantStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.buf = buf
|
||||
if it.indexed() {
|
||||
d.dynTab.add(hf)
|
||||
}
|
||||
hf.Sensitive = it.sensitive()
|
||||
return d.callEmit(hf)
|
||||
}
|
||||
|
||||
func (d *Decoder) callEmit(hf HeaderField) error {
|
||||
if d.maxStrLen != 0 {
|
||||
if len(hf.Name) > d.maxStrLen || len(hf.Value) > d.maxStrLen {
|
||||
return ErrStringLength
|
||||
}
|
||||
}
|
||||
if d.emitEnabled {
|
||||
d.emit(hf)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseDynamicTableSizeUpdate() error {
|
||||
buf := d.buf
|
||||
size, buf, err := readVarInt(5, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if size > uint64(d.dynTab.allowedMaxSize) {
|
||||
return DecodingError{errors.New("dynamic table size update too large")}
|
||||
}
|
||||
d.dynTab.setMaxSize(uint32(size))
|
||||
d.buf = buf
|
||||
return nil
|
||||
}
|
||||
|
||||
var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
|
||||
|
||||
// readVarInt reads an unsigned variable length integer off the
|
||||
// beginning of p. n is the parameter as described in
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
|
||||
//
|
||||
// n must always be between 1 and 8.
|
||||
//
|
||||
// The returned remain buffer is either a smaller suffix of p, or err != nil.
|
||||
// The error is errNeedMore if p doesn't contain a complete integer.
|
||||
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
|
||||
if n < 1 || n > 8 {
|
||||
panic("bad n")
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, p, errNeedMore
|
||||
}
|
||||
i = uint64(p[0])
|
||||
if n < 8 {
|
||||
i &= (1 << uint64(n)) - 1
|
||||
}
|
||||
if i < (1<<uint64(n))-1 {
|
||||
return i, p[1:], nil
|
||||
}
|
||||
|
||||
origP := p
|
||||
p = p[1:]
|
||||
var m uint64
|
||||
for len(p) > 0 {
|
||||
b := p[0]
|
||||
p = p[1:]
|
||||
i += uint64(b&127) << m
|
||||
if b&128 == 0 {
|
||||
return i, p, nil
|
||||
}
|
||||
m += 7
|
||||
if m >= 63 { // TODO: proper overflow check. making this up.
|
||||
return 0, origP, errVarintOverflow
|
||||
}
|
||||
}
|
||||
return 0, origP, errNeedMore
|
||||
}
|
||||
|
||||
// readString decodes an hpack string from p.
|
||||
//
|
||||
// wantStr is whether s will be used. If false, decompression and
|
||||
// []byte->string garbage are skipped if s will be ignored
|
||||
// anyway. This does mean that huffman decoding errors for non-indexed
|
||||
// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server
|
||||
// is returning an error anyway, and because they're not indexed, the error
|
||||
// won't affect the decoding state.
|
||||
func (d *Decoder) readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
|
||||
if len(p) == 0 {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
isHuff := p[0]&128 != 0
|
||||
strLen, p, err := readVarInt(7, p)
|
||||
if err != nil {
|
||||
return "", p, err
|
||||
}
|
||||
if d.maxStrLen != 0 && strLen > uint64(d.maxStrLen) {
|
||||
return "", nil, ErrStringLength
|
||||
}
|
||||
if uint64(len(p)) < strLen {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
if !isHuff {
|
||||
if wantStr {
|
||||
s = string(p[:strLen])
|
||||
}
|
||||
return s, p[strLen:], nil
|
||||
}
|
||||
|
||||
if wantStr {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset() // don't trust others
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, d.maxStrLen, p[:strLen]); err != nil {
|
||||
buf.Reset()
|
||||
return "", nil, err
|
||||
}
|
||||
s = buf.String()
|
||||
buf.Reset() // be nice to GC
|
||||
}
|
||||
return s, p[strLen:], nil
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} { return new(bytes.Buffer) },
|
||||
}
|
||||
|
||||
// HuffmanDecode decodes the string in v and writes the expanded
|
||||
// result to w, returning the number of bytes written to w and the
|
||||
// Write call's return value. At most one Write call is made.
|
||||
func HuffmanDecode(w io.Writer, v []byte) (int, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, 0, v); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
// HuffmanDecodeToString decodes the string in v.
|
||||
func HuffmanDecodeToString(v []byte) (string, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
if err := huffmanDecode(buf, 0, v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ErrInvalidHuffman is returned for errors found decoding
|
||||
// Huffman-encoded strings.
|
||||
var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
|
||||
|
||||
// huffmanDecode decodes v to buf.
|
||||
// If maxLen is greater than 0, attempts to write more to buf than
|
||||
// maxLen bytes will return ErrStringLength.
|
||||
func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
|
||||
n := rootHuffmanNode
|
||||
// cur is the bit buffer that has not been fed into n.
|
||||
// cbits is the number of low order bits in cur that are valid.
|
||||
// sbits is the number of bits of the symbol prefix being decoded.
|
||||
cur, cbits, sbits := uint(0), uint8(0), uint8(0)
|
||||
for _, b := range v {
|
||||
cur = cur<<8 | uint(b)
|
||||
cbits += 8
|
||||
sbits += 8
|
||||
for cbits >= 8 {
|
||||
idx := byte(cur >> (cbits - 8))
|
||||
n = n.children[idx]
|
||||
if n == nil {
|
||||
return ErrInvalidHuffman
|
||||
}
|
||||
if n.children == nil {
|
||||
if maxLen != 0 && buf.Len() == maxLen {
|
||||
return ErrStringLength
|
||||
}
|
||||
buf.WriteByte(n.sym)
|
||||
cbits -= n.codeLen
|
||||
n = rootHuffmanNode
|
||||
sbits = cbits
|
||||
} else {
|
||||
cbits -= 8
|
||||
}
|
||||
}
|
||||
}
|
||||
for cbits > 0 {
|
||||
n = n.children[byte(cur<<(8-cbits))]
|
||||
if n == nil {
|
||||
return ErrInvalidHuffman
|
||||
}
|
||||
if n.children != nil || n.codeLen > cbits {
|
||||
break
|
||||
}
|
||||
if maxLen != 0 && buf.Len() == maxLen {
|
||||
return ErrStringLength
|
||||
}
|
||||
buf.WriteByte(n.sym)
|
||||
cbits -= n.codeLen
|
||||
n = rootHuffmanNode
|
||||
sbits = cbits
|
||||
}
|
||||
if sbits > 7 {
|
||||
// Either there was an incomplete symbol, or overlong padding.
|
||||
// Both are decoding errors per RFC 7541 section 5.2.
|
||||
return ErrInvalidHuffman
|
||||
}
|
||||
if mask := uint(1<<cbits - 1); cur&mask != mask {
|
||||
// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
|
||||
return ErrInvalidHuffman
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type node struct {
|
||||
// children is non-nil for internal nodes
|
||||
children []*node
|
||||
|
||||
// The following are only valid if children is nil:
|
||||
codeLen uint8 // number of bits that led to the output of sym
|
||||
sym byte // output symbol
|
||||
}
|
||||
|
||||
func newInternalNode() *node {
|
||||
return &node{children: make([]*node, 256)}
|
||||
}
|
||||
|
||||
var rootHuffmanNode = newInternalNode()
|
||||
|
||||
func init() {
|
||||
if len(huffmanCodes) != 256 {
|
||||
panic("unexpected size")
|
||||
}
|
||||
for i, code := range huffmanCodes {
|
||||
addDecoderNode(byte(i), code, huffmanCodeLen[i])
|
||||
}
|
||||
}
|
||||
|
||||
func addDecoderNode(sym byte, code uint32, codeLen uint8) {
|
||||
cur := rootHuffmanNode
|
||||
for codeLen > 8 {
|
||||
codeLen -= 8
|
||||
i := uint8(code >> codeLen)
|
||||
if cur.children[i] == nil {
|
||||
cur.children[i] = newInternalNode()
|
||||
}
|
||||
cur = cur.children[i]
|
||||
}
|
||||
shift := 8 - codeLen
|
||||
start, end := int(uint8(code<<shift)), int(1<<shift)
|
||||
for i := start; i < start+end; i++ {
|
||||
cur.children[i] = &node{sym: sym, codeLen: codeLen}
|
||||
}
|
||||
}
|
||||
|
||||
// AppendHuffmanString appends s, as encoded in Huffman codes, to dst
|
||||
// and returns the extended buffer.
|
||||
func AppendHuffmanString(dst []byte, s string) []byte {
|
||||
rembits := uint8(8)
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
if rembits == 8 {
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i])
|
||||
}
|
||||
|
||||
if rembits < 8 {
|
||||
// special EOS symbol
|
||||
code := uint32(0x3fffffff)
|
||||
nbits := uint8(30)
|
||||
|
||||
t := uint8(code >> (nbits - rembits))
|
||||
dst[len(dst)-1] |= t
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// HuffmanEncodeLength returns the number of bytes required to encode
|
||||
// s in Huffman codes. The result is round up to byte boundary.
|
||||
func HuffmanEncodeLength(s string) uint64 {
|
||||
n := uint64(0)
|
||||
for i := 0; i < len(s); i++ {
|
||||
n += uint64(huffmanCodeLen[s[i]])
|
||||
}
|
||||
return (n + 7) / 8
|
||||
}
|
||||
|
||||
// appendByteToHuffmanCode appends Huffman code for c to dst and
|
||||
// returns the extended buffer and the remaining bits in the last
|
||||
// element. The appending is not byte aligned and the remaining bits
|
||||
// in the last element of dst is given in rembits.
|
||||
func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) {
|
||||
code := huffmanCodes[c]
|
||||
nbits := huffmanCodeLen[c]
|
||||
|
||||
for {
|
||||
if rembits > nbits {
|
||||
t := uint8(code << (rembits - nbits))
|
||||
dst[len(dst)-1] |= t
|
||||
rembits -= nbits
|
||||
break
|
||||
}
|
||||
|
||||
t := uint8(code >> (nbits - rembits))
|
||||
dst[len(dst)-1] |= t
|
||||
|
||||
nbits -= rembits
|
||||
rembits = 8
|
||||
|
||||
if nbits == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
return dst, rembits
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package hpack
|
||||
|
||||
func pair(name, value string) HeaderField {
|
||||
return HeaderField{Name: name, Value: value}
|
||||
}
|
||||
|
||||
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
|
||||
var staticTable = [...]HeaderField{
|
||||
pair(":authority", ""), // index 1 (1-based)
|
||||
pair(":method", "GET"),
|
||||
pair(":method", "POST"),
|
||||
pair(":path", "/"),
|
||||
pair(":path", "/index.html"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":scheme", "https"),
|
||||
pair(":status", "200"),
|
||||
pair(":status", "204"),
|
||||
pair(":status", "206"),
|
||||
pair(":status", "304"),
|
||||
pair(":status", "400"),
|
||||
pair(":status", "404"),
|
||||
pair(":status", "500"),
|
||||
pair("accept-charset", ""),
|
||||
pair("accept-encoding", "gzip, deflate"),
|
||||
pair("accept-language", ""),
|
||||
pair("accept-ranges", ""),
|
||||
pair("accept", ""),
|
||||
pair("access-control-allow-origin", ""),
|
||||
pair("age", ""),
|
||||
pair("allow", ""),
|
||||
pair("authorization", ""),
|
||||
pair("cache-control", ""),
|
||||
pair("content-disposition", ""),
|
||||
pair("content-encoding", ""),
|
||||
pair("content-language", ""),
|
||||
pair("content-length", ""),
|
||||
pair("content-location", ""),
|
||||
pair("content-range", ""),
|
||||
pair("content-type", ""),
|
||||
pair("cookie", ""),
|
||||
pair("date", ""),
|
||||
pair("etag", ""),
|
||||
pair("expect", ""),
|
||||
pair("expires", ""),
|
||||
pair("from", ""),
|
||||
pair("host", ""),
|
||||
pair("if-match", ""),
|
||||
pair("if-modified-since", ""),
|
||||
pair("if-none-match", ""),
|
||||
pair("if-range", ""),
|
||||
pair("if-unmodified-since", ""),
|
||||
pair("last-modified", ""),
|
||||
pair("link", ""),
|
||||
pair("location", ""),
|
||||
pair("max-forwards", ""),
|
||||
pair("proxy-authenticate", ""),
|
||||
pair("proxy-authorization", ""),
|
||||
pair("range", ""),
|
||||
pair("referer", ""),
|
||||
pair("refresh", ""),
|
||||
pair("retry-after", ""),
|
||||
pair("server", ""),
|
||||
pair("set-cookie", ""),
|
||||
pair("strict-transport-security", ""),
|
||||
pair("transfer-encoding", ""),
|
||||
pair("user-agent", ""),
|
||||
pair("vary", ""),
|
||||
pair("via", ""),
|
||||
pair("www-authenticate", ""),
|
||||
}
|
||||
|
||||
var huffmanCodes = [256]uint32{
|
||||
0x1ff8,
|
||||
0x7fffd8,
|
||||
0xfffffe2,
|
||||
0xfffffe3,
|
||||
0xfffffe4,
|
||||
0xfffffe5,
|
||||
0xfffffe6,
|
||||
0xfffffe7,
|
||||
0xfffffe8,
|
||||
0xffffea,
|
||||
0x3ffffffc,
|
||||
0xfffffe9,
|
||||
0xfffffea,
|
||||
0x3ffffffd,
|
||||
0xfffffeb,
|
||||
0xfffffec,
|
||||
0xfffffed,
|
||||
0xfffffee,
|
||||
0xfffffef,
|
||||
0xffffff0,
|
||||
0xffffff1,
|
||||
0xffffff2,
|
||||
0x3ffffffe,
|
||||
0xffffff3,
|
||||
0xffffff4,
|
||||
0xffffff5,
|
||||
0xffffff6,
|
||||
0xffffff7,
|
||||
0xffffff8,
|
||||
0xffffff9,
|
||||
0xffffffa,
|
||||
0xffffffb,
|
||||
0x14,
|
||||
0x3f8,
|
||||
0x3f9,
|
||||
0xffa,
|
||||
0x1ff9,
|
||||
0x15,
|
||||
0xf8,
|
||||
0x7fa,
|
||||
0x3fa,
|
||||
0x3fb,
|
||||
0xf9,
|
||||
0x7fb,
|
||||
0xfa,
|
||||
0x16,
|
||||
0x17,
|
||||
0x18,
|
||||
0x0,
|
||||
0x1,
|
||||
0x2,
|
||||
0x19,
|
||||
0x1a,
|
||||
0x1b,
|
||||
0x1c,
|
||||
0x1d,
|
||||
0x1e,
|
||||
0x1f,
|
||||
0x5c,
|
||||
0xfb,
|
||||
0x7ffc,
|
||||
0x20,
|
||||
0xffb,
|
||||
0x3fc,
|
||||
0x1ffa,
|
||||
0x21,
|
||||
0x5d,
|
||||
0x5e,
|
||||
0x5f,
|
||||
0x60,
|
||||
0x61,
|
||||
0x62,
|
||||
0x63,
|
||||
0x64,
|
||||
0x65,
|
||||
0x66,
|
||||
0x67,
|
||||
0x68,
|
||||
0x69,
|
||||
0x6a,
|
||||
0x6b,
|
||||
0x6c,
|
||||
0x6d,
|
||||
0x6e,
|
||||
0x6f,
|
||||
0x70,
|
||||
0x71,
|
||||
0x72,
|
||||
0xfc,
|
||||
0x73,
|
||||
0xfd,
|
||||
0x1ffb,
|
||||
0x7fff0,
|
||||
0x1ffc,
|
||||
0x3ffc,
|
||||
0x22,
|
||||
0x7ffd,
|
||||
0x3,
|
||||
0x23,
|
||||
0x4,
|
||||
0x24,
|
||||
0x5,
|
||||
0x25,
|
||||
0x26,
|
||||
0x27,
|
||||
0x6,
|
||||
0x74,
|
||||
0x75,
|
||||
0x28,
|
||||
0x29,
|
||||
0x2a,
|
||||
0x7,
|
||||
0x2b,
|
||||
0x76,
|
||||
0x2c,
|
||||
0x8,
|
||||
0x9,
|
||||
0x2d,
|
||||
0x77,
|
||||
0x78,
|
||||
0x79,
|
||||
0x7a,
|
||||
0x7b,
|
||||
0x7ffe,
|
||||
0x7fc,
|
||||
0x3ffd,
|
||||
0x1ffd,
|
||||
0xffffffc,
|
||||
0xfffe6,
|
||||
0x3fffd2,
|
||||
0xfffe7,
|
||||
0xfffe8,
|
||||
0x3fffd3,
|
||||
0x3fffd4,
|
||||
0x3fffd5,
|
||||
0x7fffd9,
|
||||
0x3fffd6,
|
||||
0x7fffda,
|
||||
0x7fffdb,
|
||||
0x7fffdc,
|
||||
0x7fffdd,
|
||||
0x7fffde,
|
||||
0xffffeb,
|
||||
0x7fffdf,
|
||||
0xffffec,
|
||||
0xffffed,
|
||||
0x3fffd7,
|
||||
0x7fffe0,
|
||||
0xffffee,
|
||||
0x7fffe1,
|
||||
0x7fffe2,
|
||||
0x7fffe3,
|
||||
0x7fffe4,
|
||||
0x1fffdc,
|
||||
0x3fffd8,
|
||||
0x7fffe5,
|
||||
0x3fffd9,
|
||||
0x7fffe6,
|
||||
0x7fffe7,
|
||||
0xffffef,
|
||||
0x3fffda,
|
||||
0x1fffdd,
|
||||
0xfffe9,
|
||||
0x3fffdb,
|
||||
0x3fffdc,
|
||||
0x7fffe8,
|
||||
0x7fffe9,
|
||||
0x1fffde,
|
||||
0x7fffea,
|
||||
0x3fffdd,
|
||||
0x3fffde,
|
||||
0xfffff0,
|
||||
0x1fffdf,
|
||||
0x3fffdf,
|
||||
0x7fffeb,
|
||||
0x7fffec,
|
||||
0x1fffe0,
|
||||
0x1fffe1,
|
||||
0x3fffe0,
|
||||
0x1fffe2,
|
||||
0x7fffed,
|
||||
0x3fffe1,
|
||||
0x7fffee,
|
||||
0x7fffef,
|
||||
0xfffea,
|
||||
0x3fffe2,
|
||||
0x3fffe3,
|
||||
0x3fffe4,
|
||||
0x7ffff0,
|
||||
0x3fffe5,
|
||||
0x3fffe6,
|
||||
0x7ffff1,
|
||||
0x3ffffe0,
|
||||
0x3ffffe1,
|
||||
0xfffeb,
|
||||
0x7fff1,
|
||||
0x3fffe7,
|
||||
0x7ffff2,
|
||||
0x3fffe8,
|
||||
0x1ffffec,
|
||||
0x3ffffe2,
|
||||
0x3ffffe3,
|
||||
0x3ffffe4,
|
||||
0x7ffffde,
|
||||
0x7ffffdf,
|
||||
0x3ffffe5,
|
||||
0xfffff1,
|
||||
0x1ffffed,
|
||||
0x7fff2,
|
||||
0x1fffe3,
|
||||
0x3ffffe6,
|
||||
0x7ffffe0,
|
||||
0x7ffffe1,
|
||||
0x3ffffe7,
|
||||
0x7ffffe2,
|
||||
0xfffff2,
|
||||
0x1fffe4,
|
||||
0x1fffe5,
|
||||
0x3ffffe8,
|
||||
0x3ffffe9,
|
||||
0xffffffd,
|
||||
0x7ffffe3,
|
||||
0x7ffffe4,
|
||||
0x7ffffe5,
|
||||
0xfffec,
|
||||
0xfffff3,
|
||||
0xfffed,
|
||||
0x1fffe6,
|
||||
0x3fffe9,
|
||||
0x1fffe7,
|
||||
0x1fffe8,
|
||||
0x7ffff3,
|
||||
0x3fffea,
|
||||
0x3fffeb,
|
||||
0x1ffffee,
|
||||
0x1ffffef,
|
||||
0xfffff4,
|
||||
0xfffff5,
|
||||
0x3ffffea,
|
||||
0x7ffff4,
|
||||
0x3ffffeb,
|
||||
0x7ffffe6,
|
||||
0x3ffffec,
|
||||
0x3ffffed,
|
||||
0x7ffffe7,
|
||||
0x7ffffe8,
|
||||
0x7ffffe9,
|
||||
0x7ffffea,
|
||||
0x7ffffeb,
|
||||
0xffffffe,
|
||||
0x7ffffec,
|
||||
0x7ffffed,
|
||||
0x7ffffee,
|
||||
0x7ffffef,
|
||||
0x7fffff0,
|
||||
0x3ffffee,
|
||||
}
|
||||
|
||||
var huffmanCodeLen = [256]uint8{
|
||||
13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28,
|
||||
28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||
6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6,
|
||||
5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10,
|
||||
13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
||||
7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6,
|
||||
15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5,
|
||||
6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28,
|
||||
20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23,
|
||||
24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24,
|
||||
22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23,
|
||||
21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23,
|
||||
26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25,
|
||||
19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27,
|
||||
20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23,
|
||||
26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26,
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package http2 implements the HTTP/2 protocol.
|
||||
//
|
||||
// This package is low-level and intended to be used directly by very
|
||||
// few people. Most users will use it indirectly through the automatic
|
||||
// use by the net/http package (from Go 1.6 and later).
|
||||
// For use in earlier Go versions see ConfigureServer. (Transport support
|
||||
// requires Go 1.6 or later)
|
||||
//
|
||||
// See https://http2.github.io/ for more information on HTTP/2.
|
||||
//
|
||||
// See https://http2.golang.org/ for a test server running this code.
|
||||
//
|
||||
package http2 // import "golang.org/x/net/http2"
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/lex/httplex"
|
||||
)
|
||||
|
||||
var (
|
||||
VerboseLogs bool
|
||||
logFrameWrites bool
|
||||
logFrameReads bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
e := os.Getenv("GODEBUG")
|
||||
if strings.Contains(e, "http2debug=1") {
|
||||
VerboseLogs = true
|
||||
}
|
||||
if strings.Contains(e, "http2debug=2") {
|
||||
VerboseLogs = true
|
||||
logFrameWrites = true
|
||||
logFrameReads = true
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// ClientPreface is the string that must be sent by new
|
||||
// connections from clients.
|
||||
ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
|
||||
|
||||
// SETTINGS_MAX_FRAME_SIZE default
|
||||
// http://http2.github.io/http2-spec/#rfc.section.6.5.2
|
||||
initialMaxFrameSize = 16384
|
||||
|
||||
// NextProtoTLS is the NPN/ALPN protocol negotiated during
|
||||
// HTTP/2's TLS setup.
|
||||
NextProtoTLS = "h2"
|
||||
|
||||
// http://http2.github.io/http2-spec/#SettingValues
|
||||
initialHeaderTableSize = 4096
|
||||
|
||||
initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
|
||||
|
||||
defaultMaxReadFrameSize = 1 << 20
|
||||
)
|
||||
|
||||
var (
|
||||
clientPreface = []byte(ClientPreface)
|
||||
)
|
||||
|
||||
type streamState int
|
||||
|
||||
const (
|
||||
stateIdle streamState = iota
|
||||
stateOpen
|
||||
stateHalfClosedLocal
|
||||
stateHalfClosedRemote
|
||||
stateResvLocal
|
||||
stateResvRemote
|
||||
stateClosed
|
||||
)
|
||||
|
||||
var stateName = [...]string{
|
||||
stateIdle: "Idle",
|
||||
stateOpen: "Open",
|
||||
stateHalfClosedLocal: "HalfClosedLocal",
|
||||
stateHalfClosedRemote: "HalfClosedRemote",
|
||||
stateResvLocal: "ResvLocal",
|
||||
stateResvRemote: "ResvRemote",
|
||||
stateClosed: "Closed",
|
||||
}
|
||||
|
||||
func (st streamState) String() string {
|
||||
return stateName[st]
|
||||
}
|
||||
|
||||
// Setting is a setting parameter: which setting it is, and its value.
|
||||
type Setting struct {
|
||||
// ID is which setting is being set.
|
||||
// See http://http2.github.io/http2-spec/#SettingValues
|
||||
ID SettingID
|
||||
|
||||
// Val is the value.
|
||||
Val uint32
|
||||
}
|
||||
|
||||
func (s Setting) String() string {
|
||||
return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
|
||||
}
|
||||
|
||||
// Valid reports whether the setting is valid.
|
||||
func (s Setting) Valid() error {
|
||||
// Limits and error codes from 6.5.2 Defined SETTINGS Parameters
|
||||
switch s.ID {
|
||||
case SettingEnablePush:
|
||||
if s.Val != 1 && s.Val != 0 {
|
||||
return ConnectionError(ErrCodeProtocol)
|
||||
}
|
||||
case SettingInitialWindowSize:
|
||||
if s.Val > 1<<31-1 {
|
||||
return ConnectionError(ErrCodeFlowControl)
|
||||
}
|
||||
case SettingMaxFrameSize:
|
||||
if s.Val < 16384 || s.Val > 1<<24-1 {
|
||||
return ConnectionError(ErrCodeProtocol)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A SettingID is an HTTP/2 setting as defined in
|
||||
// http://http2.github.io/http2-spec/#iana-settings
|
||||
type SettingID uint16
|
||||
|
||||
const (
|
||||
SettingHeaderTableSize SettingID = 0x1
|
||||
SettingEnablePush SettingID = 0x2
|
||||
SettingMaxConcurrentStreams SettingID = 0x3
|
||||
SettingInitialWindowSize SettingID = 0x4
|
||||
SettingMaxFrameSize SettingID = 0x5
|
||||
SettingMaxHeaderListSize SettingID = 0x6
|
||||
)
|
||||
|
||||
var settingName = map[SettingID]string{
|
||||
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
|
||||
SettingEnablePush: "ENABLE_PUSH",
|
||||
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
|
||||
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
|
||||
SettingMaxFrameSize: "MAX_FRAME_SIZE",
|
||||
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
|
||||
}
|
||||
|
||||
func (s SettingID) String() string {
|
||||
if v, ok := settingName[s]; ok {
|
||||
return v
|
||||
}
|
||||
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
|
||||
}
|
||||
|
||||
var (
|
||||
errInvalidHeaderFieldName = errors.New("http2: invalid header field name")
|
||||
errInvalidHeaderFieldValue = errors.New("http2: invalid header field value")
|
||||
)
|
||||
|
||||
// validWireHeaderFieldName reports whether v is a valid header field
|
||||
// name (key). See httplex.ValidHeaderName for the base rules.
|
||||
//
|
||||
// Further, http2 says:
|
||||
// "Just as in HTTP/1.x, header field names are strings of ASCII
|
||||
// characters that are compared in a case-insensitive
|
||||
// fashion. However, header field names MUST be converted to
|
||||
// lowercase prior to their encoding in HTTP/2. "
|
||||
func validWireHeaderFieldName(v string) bool {
|
||||
if len(v) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range v {
|
||||
if !httplex.IsTokenRune(r) {
|
||||
return false
|
||||
}
|
||||
if 'A' <= r && r <= 'Z' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n)
|
||||
|
||||
func init() {
|
||||
for i := 100; i <= 999; i++ {
|
||||
if v := http.StatusText(i); v != "" {
|
||||
httpCodeStringCommon[i] = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func httpCodeString(code int) string {
|
||||
if s, ok := httpCodeStringCommon[code]; ok {
|
||||
return s
|
||||
}
|
||||
return strconv.Itoa(code)
|
||||
}
|
||||
|
||||
// from pkg io
|
||||
type stringWriter interface {
|
||||
WriteString(s string) (n int, err error)
|
||||
}
|
||||
|
||||
// A gate lets two goroutines coordinate their activities.
|
||||
type gate chan struct{}
|
||||
|
||||
func (g gate) Done() { g <- struct{}{} }
|
||||
func (g gate) Wait() { <-g }
|
||||
|
||||
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
|
||||
type closeWaiter chan struct{}
|
||||
|
||||
// Init makes a closeWaiter usable.
|
||||
// It exists because so a closeWaiter value can be placed inside a
|
||||
// larger struct and have the Mutex and Cond's memory in the same
|
||||
// allocation.
|
||||
func (cw *closeWaiter) Init() {
|
||||
*cw = make(chan struct{})
|
||||
}
|
||||
|
||||
// Close marks the closeWaiter as closed and unblocks any waiters.
|
||||
func (cw closeWaiter) Close() {
|
||||
close(cw)
|
||||
}
|
||||
|
||||
// Wait waits for the closeWaiter to become closed.
|
||||
func (cw closeWaiter) Wait() {
|
||||
<-cw
|
||||
}
|
||||
|
||||
// bufferedWriter is a buffered writer that writes to w.
|
||||
// Its buffered writer is lazily allocated as needed, to minimize
|
||||
// idle memory usage with many connections.
|
||||
type bufferedWriter struct {
|
||||
w io.Writer // immutable
|
||||
bw *bufio.Writer // non-nil when data is buffered
|
||||
}
|
||||
|
||||
func newBufferedWriter(w io.Writer) *bufferedWriter {
|
||||
return &bufferedWriter{w: w}
|
||||
}
|
||||
|
||||
var bufWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
// TODO: pick something better? this is a bit under
|
||||
// (3 x typical 1500 byte MTU) at least.
|
||||
return bufio.NewWriterSize(nil, 4<<10)
|
||||
},
|
||||
}
|
||||
|
||||
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.bw == nil {
|
||||
bw := bufWriterPool.Get().(*bufio.Writer)
|
||||
bw.Reset(w.w)
|
||||
w.bw = bw
|
||||
}
|
||||
return w.bw.Write(p)
|
||||
}
|
||||
|
||||
func (w *bufferedWriter) Flush() error {
|
||||
bw := w.bw
|
||||
if bw == nil {
|
||||
return nil
|
||||
}
|
||||
err := bw.Flush()
|
||||
bw.Reset(nil)
|
||||
bufWriterPool.Put(bw)
|
||||
w.bw = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func mustUint31(v int32) uint32 {
|
||||
if v < 0 || v > 2147483647 {
|
||||
panic("out of range")
|
||||
}
|
||||
return uint32(v)
|
||||
}
|
||||
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type httpError struct {
|
||||
msg string
|
||||
timeout bool
|
||||
}
|
||||
|
||||
func (e *httpError) Error() string { return e.msg }
|
||||
func (e *httpError) Timeout() bool { return e.timeout }
|
||||
func (e *httpError) Temporary() bool { return true }
|
||||
|
||||
var errTimeout error = &httpError{msg: "http2: timeout awaiting response headers", timeout: true}
|
||||
|
||||
type connectionStater interface {
|
||||
ConnectionState() tls.ConnectionState
|
||||
}
|
||||
|
||||
var sorterPool = sync.Pool{New: func() interface{} { return new(sorter) }}
|
||||
|
||||
type sorter struct {
|
||||
v []string // owned by sorter
|
||||
}
|
||||
|
||||
func (s *sorter) Len() int { return len(s.v) }
|
||||
func (s *sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] }
|
||||
func (s *sorter) Less(i, j int) bool { return s.v[i] < s.v[j] }
|
||||
|
||||
// Keys returns the sorted keys of h.
|
||||
//
|
||||
// The returned slice is only valid until s used again or returned to
|
||||
// its pool.
|
||||
func (s *sorter) Keys(h http.Header) []string {
|
||||
keys := s.v[:0]
|
||||
for k := range h {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
s.v = keys
|
||||
sort.Sort(s)
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *sorter) SortStrings(ss []string) {
|
||||
// Our sorter works on s.v, which sorter owners, so
|
||||
// stash it away while we sort the user's buffer.
|
||||
save := s.v
|
||||
s.v = ss
|
||||
sort.Sort(s)
|
||||
s.v = save
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.6
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func configureTransport(t1 *http.Transport) (*Transport, error) {
|
||||
return nil, errTransportVersion
|
||||
}
|
||||
|
||||
func transportExpectContinueTimeout(t1 *http.Transport) time.Duration {
|
||||
return 0
|
||||
|
||||
}
|
||||
|
||||
// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec.
|
||||
func isBadCipher(cipher uint16) bool {
|
||||
switch cipher {
|
||||
case tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
|
||||
// Reject cipher suites from Appendix A.
|
||||
// "This list includes those cipher suites that do not
|
||||
// offer an ephemeral key exchange and those that are
|
||||
// based on the TLS null, stream or block cipher type"
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.7
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type contextContext interface{}
|
||||
|
||||
type fakeContext struct{}
|
||||
|
||||
func (fakeContext) Done() <-chan struct{} { return nil }
|
||||
func (fakeContext) Err() error { panic("should not be called") }
|
||||
|
||||
func reqContext(r *http.Request) fakeContext {
|
||||
return fakeContext{}
|
||||
}
|
||||
|
||||
func setResponseUncompressed(res *http.Response) {
|
||||
// Nothing.
|
||||
}
|
||||
|
||||
type clientTrace struct{}
|
||||
|
||||
func requestTrace(*http.Request) *clientTrace { return nil }
|
||||
func traceGotConn(*http.Request, *ClientConn) {}
|
||||
func traceFirstResponseByte(*clientTrace) {}
|
||||
func traceWroteHeaders(*clientTrace) {}
|
||||
func traceWroteRequest(*clientTrace, error) {}
|
||||
func traceGot100Continue(trace *clientTrace) {}
|
||||
func traceWait100Continue(trace *clientTrace) {}
|
||||
|
||||
func nop() {}
|
||||
|
||||
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
|
||||
return nil, nop
|
||||
}
|
||||
|
||||
func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
|
||||
return ctx, nop
|
||||
}
|
||||
|
||||
func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
|
||||
return req
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
|
||||
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
|
||||
// underlying buffer is an interface. (io.Pipe is always unbuffered)
|
||||
type pipe struct {
|
||||
mu sync.Mutex
|
||||
c sync.Cond // c.L lazily initialized to &p.mu
|
||||
b pipeBuffer
|
||||
err error // read error once empty. non-nil means closed.
|
||||
breakErr error // immediate read error (caller doesn't see rest of b)
|
||||
donec chan struct{} // closed on error
|
||||
readFn func() // optional code to run in Read before error
|
||||
}
|
||||
|
||||
type pipeBuffer interface {
|
||||
Len() int
|
||||
io.Writer
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (p *pipe) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.b.Len()
|
||||
}
|
||||
|
||||
// Read waits until data is available and copies bytes
|
||||
// from the buffer into p.
|
||||
func (p *pipe) Read(d []byte) (n int, err error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.c.L == nil {
|
||||
p.c.L = &p.mu
|
||||
}
|
||||
for {
|
||||
if p.breakErr != nil {
|
||||
return 0, p.breakErr
|
||||
}
|
||||
if p.b.Len() > 0 {
|
||||
return p.b.Read(d)
|
||||
}
|
||||
if p.err != nil {
|
||||
if p.readFn != nil {
|
||||
p.readFn() // e.g. copy trailers
|
||||
p.readFn = nil // not sticky like p.err
|
||||
}
|
||||
return 0, p.err
|
||||
}
|
||||
p.c.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
var errClosedPipeWrite = errors.New("write on closed buffer")
|
||||
|
||||
// Write copies bytes from p into the buffer and wakes a reader.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (p *pipe) Write(d []byte) (n int, err error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.c.L == nil {
|
||||
p.c.L = &p.mu
|
||||
}
|
||||
defer p.c.Signal()
|
||||
if p.err != nil {
|
||||
return 0, errClosedPipeWrite
|
||||
}
|
||||
return p.b.Write(d)
|
||||
}
|
||||
|
||||
// CloseWithError causes the next Read (waking up a current blocked
|
||||
// Read if needed) to return the provided err after all data has been
|
||||
// read.
|
||||
//
|
||||
// The error must be non-nil.
|
||||
func (p *pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) }
|
||||
|
||||
// BreakWithError causes the next Read (waking up a current blocked
|
||||
// Read if needed) to return the provided err immediately, without
|
||||
// waiting for unread data.
|
||||
func (p *pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) }
|
||||
|
||||
// closeWithErrorAndCode is like CloseWithError but also sets some code to run
|
||||
// in the caller's goroutine before returning the error.
|
||||
func (p *pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) }
|
||||
|
||||
func (p *pipe) closeWithError(dst *error, err error, fn func()) {
|
||||
if err == nil {
|
||||
panic("err must be non-nil")
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.c.L == nil {
|
||||
p.c.L = &p.mu
|
||||
}
|
||||
defer p.c.Signal()
|
||||
if *dst != nil {
|
||||
// Already been done.
|
||||
return
|
||||
}
|
||||
p.readFn = fn
|
||||
*dst = err
|
||||
p.closeDoneLocked()
|
||||
}
|
||||
|
||||
// requires p.mu be held.
|
||||
func (p *pipe) closeDoneLocked() {
|
||||
if p.donec == nil {
|
||||
return
|
||||
}
|
||||
// Close if unclosed. This isn't racy since we always
|
||||
// hold p.mu while closing.
|
||||
select {
|
||||
case <-p.donec:
|
||||
default:
|
||||
close(p.donec)
|
||||
}
|
||||
}
|
||||
|
||||
// Err returns the error (if any) first set by BreakWithError or CloseWithError.
|
||||
func (p *pipe) Err() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.breakErr != nil {
|
||||
return p.breakErr
|
||||
}
|
||||
return p.err
|
||||
}
|
||||
|
||||
// Done returns a channel which is closed if and when this pipe is closed
|
||||
// with CloseWithError.
|
||||
func (p *pipe) Done() <-chan struct{} {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.donec == nil {
|
||||
p.donec = make(chan struct{})
|
||||
if p.err != nil || p.breakErr != nil {
|
||||
// Already hit an error.
|
||||
p.closeDoneLocked()
|
||||
}
|
||||
}
|
||||
return p.donec
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,264 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/lex/httplex"
|
||||
)
|
||||
|
||||
// writeFramer is implemented by any type that is used to write frames.
|
||||
type writeFramer interface {
|
||||
writeFrame(writeContext) error
|
||||
}
|
||||
|
||||
// writeContext is the interface needed by the various frame writer
|
||||
// types below. All the writeFrame methods below are scheduled via the
|
||||
// frame writing scheduler (see writeScheduler in writesched.go).
|
||||
//
|
||||
// This interface is implemented by *serverConn.
|
||||
//
|
||||
// TODO: decide whether to a) use this in the client code (which didn't
|
||||
// end up using this yet, because it has a simpler design, not
|
||||
// currently implementing priorities), or b) delete this and
|
||||
// make the server code a bit more concrete.
|
||||
type writeContext interface {
|
||||
Framer() *Framer
|
||||
Flush() error
|
||||
CloseConn() error
|
||||
// HeaderEncoder returns an HPACK encoder that writes to the
|
||||
// returned buffer.
|
||||
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
|
||||
}
|
||||
|
||||
// endsStream reports whether the given frame writer w will locally
|
||||
// close the stream.
|
||||
func endsStream(w writeFramer) bool {
|
||||
switch v := w.(type) {
|
||||
case *writeData:
|
||||
return v.endStream
|
||||
case *writeResHeaders:
|
||||
return v.endStream
|
||||
case nil:
|
||||
// This can only happen if the caller reuses w after it's
|
||||
// been intentionally nil'ed out to prevent use. Keep this
|
||||
// here to catch future refactoring breaking it.
|
||||
panic("endsStream called on nil writeFramer")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type flushFrameWriter struct{}
|
||||
|
||||
func (flushFrameWriter) writeFrame(ctx writeContext) error {
|
||||
return ctx.Flush()
|
||||
}
|
||||
|
||||
type writeSettings []Setting
|
||||
|
||||
func (s writeSettings) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteSettings([]Setting(s)...)
|
||||
}
|
||||
|
||||
type writeGoAway struct {
|
||||
maxStreamID uint32
|
||||
code ErrCode
|
||||
}
|
||||
|
||||
func (p *writeGoAway) writeFrame(ctx writeContext) error {
|
||||
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
|
||||
if p.code != 0 {
|
||||
ctx.Flush() // ignore error: we're hanging up on them anyway
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
ctx.CloseConn()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type writeData struct {
|
||||
streamID uint32
|
||||
p []byte
|
||||
endStream bool
|
||||
}
|
||||
|
||||
func (w *writeData) String() string {
|
||||
return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
|
||||
}
|
||||
|
||||
func (w *writeData) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
|
||||
}
|
||||
|
||||
// handlerPanicRST is the message sent from handler goroutines when
|
||||
// the handler panics.
|
||||
type handlerPanicRST struct {
|
||||
StreamID uint32
|
||||
}
|
||||
|
||||
func (hp handlerPanicRST) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteRSTStream(hp.StreamID, ErrCodeInternal)
|
||||
}
|
||||
|
||||
func (se StreamError) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
|
||||
}
|
||||
|
||||
type writePingAck struct{ pf *PingFrame }
|
||||
|
||||
func (w writePingAck) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WritePing(true, w.pf.Data)
|
||||
}
|
||||
|
||||
type writeSettingsAck struct{}
|
||||
|
||||
func (writeSettingsAck) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteSettingsAck()
|
||||
}
|
||||
|
||||
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
|
||||
// for HTTP response headers or trailers from a server handler.
|
||||
type writeResHeaders struct {
|
||||
streamID uint32
|
||||
httpResCode int // 0 means no ":status" line
|
||||
h http.Header // may be nil
|
||||
trailers []string // if non-nil, which keys of h to write. nil means all.
|
||||
endStream bool
|
||||
|
||||
date string
|
||||
contentType string
|
||||
contentLength string
|
||||
}
|
||||
|
||||
func encKV(enc *hpack.Encoder, k, v string) {
|
||||
if VerboseLogs {
|
||||
log.Printf("http2: server encoding header %q = %q", k, v)
|
||||
}
|
||||
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
|
||||
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
|
||||
enc, buf := ctx.HeaderEncoder()
|
||||
buf.Reset()
|
||||
|
||||
if w.httpResCode != 0 {
|
||||
encKV(enc, ":status", httpCodeString(w.httpResCode))
|
||||
}
|
||||
|
||||
encodeHeaders(enc, w.h, w.trailers)
|
||||
|
||||
if w.contentType != "" {
|
||||
encKV(enc, "content-type", w.contentType)
|
||||
}
|
||||
if w.contentLength != "" {
|
||||
encKV(enc, "content-length", w.contentLength)
|
||||
}
|
||||
if w.date != "" {
|
||||
encKV(enc, "date", w.date)
|
||||
}
|
||||
|
||||
headerBlock := buf.Bytes()
|
||||
if len(headerBlock) == 0 && w.trailers == nil {
|
||||
panic("unexpected empty hpack")
|
||||
}
|
||||
|
||||
// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
|
||||
// that all peers must support (16KB). Later we could care
|
||||
// more and send larger frames if the peer advertised it, but
|
||||
// there's little point. Most headers are small anyway (so we
|
||||
// generally won't have CONTINUATION frames), and extra frames
|
||||
// only waste 9 bytes anyway.
|
||||
const maxFrameSize = 16384
|
||||
|
||||
first := true
|
||||
for len(headerBlock) > 0 {
|
||||
frag := headerBlock
|
||||
if len(frag) > maxFrameSize {
|
||||
frag = frag[:maxFrameSize]
|
||||
}
|
||||
headerBlock = headerBlock[len(frag):]
|
||||
endHeaders := len(headerBlock) == 0
|
||||
var err error
|
||||
if first {
|
||||
first = false
|
||||
err = ctx.Framer().WriteHeaders(HeadersFrameParam{
|
||||
StreamID: w.streamID,
|
||||
BlockFragment: frag,
|
||||
EndStream: w.endStream,
|
||||
EndHeaders: endHeaders,
|
||||
})
|
||||
} else {
|
||||
err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type write100ContinueHeadersFrame struct {
|
||||
streamID uint32
|
||||
}
|
||||
|
||||
func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error {
|
||||
enc, buf := ctx.HeaderEncoder()
|
||||
buf.Reset()
|
||||
encKV(enc, ":status", "100")
|
||||
return ctx.Framer().WriteHeaders(HeadersFrameParam{
|
||||
StreamID: w.streamID,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
type writeWindowUpdate struct {
|
||||
streamID uint32 // or 0 for conn-level
|
||||
n uint32
|
||||
}
|
||||
|
||||
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
|
||||
}
|
||||
|
||||
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
|
||||
if keys == nil {
|
||||
sorter := sorterPool.Get().(*sorter)
|
||||
// Using defer here, since the returned keys from the
|
||||
// sorter.Keys method is only valid until the sorter
|
||||
// is returned:
|
||||
defer sorterPool.Put(sorter)
|
||||
keys = sorter.Keys(h)
|
||||
}
|
||||
for _, k := range keys {
|
||||
vv := h[k]
|
||||
k = lowerHeader(k)
|
||||
if !validWireHeaderFieldName(k) {
|
||||
// Skip it as backup paranoia. Per
|
||||
// golang.org/issue/14048, these should
|
||||
// already be rejected at a higher level.
|
||||
continue
|
||||
}
|
||||
isTE := k == "transfer-encoding"
|
||||
for _, v := range vv {
|
||||
if !httplex.ValidHeaderFieldValue(v) {
|
||||
// TODO: return an error? golang.org/issue/14048
|
||||
// For now just omit it.
|
||||
continue
|
||||
}
|
||||
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
|
||||
if isTE && v != "trailers" {
|
||||
continue
|
||||
}
|
||||
encKV(enc, k, v)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,283 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import "fmt"
|
||||
|
||||
// frameWriteMsg is a request to write a frame.
|
||||
type frameWriteMsg struct {
|
||||
// write is the interface value that does the writing, once the
|
||||
// writeScheduler (below) has decided to select this frame
|
||||
// to write. The write functions are all defined in write.go.
|
||||
write writeFramer
|
||||
|
||||
stream *stream // used for prioritization. nil for non-stream frames.
|
||||
|
||||
// done, if non-nil, must be a buffered channel with space for
|
||||
// 1 message and is sent the return value from write (or an
|
||||
// earlier error) when the frame has been written.
|
||||
done chan error
|
||||
}
|
||||
|
||||
// for debugging only:
|
||||
func (wm frameWriteMsg) String() string {
|
||||
var streamID uint32
|
||||
if wm.stream != nil {
|
||||
streamID = wm.stream.id
|
||||
}
|
||||
var des string
|
||||
if s, ok := wm.write.(fmt.Stringer); ok {
|
||||
des = s.String()
|
||||
} else {
|
||||
des = fmt.Sprintf("%T", wm.write)
|
||||
}
|
||||
return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
|
||||
}
|
||||
|
||||
// writeScheduler tracks pending frames to write, priorities, and decides
|
||||
// the next one to use. It is not thread-safe.
|
||||
type writeScheduler struct {
|
||||
// zero are frames not associated with a specific stream.
|
||||
// They're sent before any stream-specific freams.
|
||||
zero writeQueue
|
||||
|
||||
// maxFrameSize is the maximum size of a DATA frame
|
||||
// we'll write. Must be non-zero and between 16K-16M.
|
||||
maxFrameSize uint32
|
||||
|
||||
// sq contains the stream-specific queues, keyed by stream ID.
|
||||
// when a stream is idle, it's deleted from the map.
|
||||
sq map[uint32]*writeQueue
|
||||
|
||||
// canSend is a slice of memory that's reused between frame
|
||||
// scheduling decisions to hold the list of writeQueues (from sq)
|
||||
// which have enough flow control data to send. After canSend is
|
||||
// built, the best is selected.
|
||||
canSend []*writeQueue
|
||||
|
||||
// pool of empty queues for reuse.
|
||||
queuePool []*writeQueue
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) putEmptyQueue(q *writeQueue) {
|
||||
if len(q.s) != 0 {
|
||||
panic("queue must be empty")
|
||||
}
|
||||
ws.queuePool = append(ws.queuePool, q)
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) getEmptyQueue() *writeQueue {
|
||||
ln := len(ws.queuePool)
|
||||
if ln == 0 {
|
||||
return new(writeQueue)
|
||||
}
|
||||
q := ws.queuePool[ln-1]
|
||||
ws.queuePool = ws.queuePool[:ln-1]
|
||||
return q
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
|
||||
|
||||
func (ws *writeScheduler) add(wm frameWriteMsg) {
|
||||
st := wm.stream
|
||||
if st == nil {
|
||||
ws.zero.push(wm)
|
||||
} else {
|
||||
ws.streamQueue(st.id).push(wm)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue {
|
||||
if q, ok := ws.sq[streamID]; ok {
|
||||
return q
|
||||
}
|
||||
if ws.sq == nil {
|
||||
ws.sq = make(map[uint32]*writeQueue)
|
||||
}
|
||||
q := ws.getEmptyQueue()
|
||||
ws.sq[streamID] = q
|
||||
return q
|
||||
}
|
||||
|
||||
// take returns the most important frame to write and removes it from the scheduler.
|
||||
// It is illegal to call this if the scheduler is empty or if there are no connection-level
|
||||
// flow control bytes available.
|
||||
func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) {
|
||||
if ws.maxFrameSize == 0 {
|
||||
panic("internal error: ws.maxFrameSize not initialized or invalid")
|
||||
}
|
||||
|
||||
// If there any frames not associated with streams, prefer those first.
|
||||
// These are usually SETTINGS, etc.
|
||||
if !ws.zero.empty() {
|
||||
return ws.zero.shift(), true
|
||||
}
|
||||
if len(ws.sq) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Next, prioritize frames on streams that aren't DATA frames (no cost).
|
||||
for id, q := range ws.sq {
|
||||
if q.firstIsNoCost() {
|
||||
return ws.takeFrom(id, q)
|
||||
}
|
||||
}
|
||||
|
||||
// Now, all that remains are DATA frames with non-zero bytes to
|
||||
// send. So pick the best one.
|
||||
if len(ws.canSend) != 0 {
|
||||
panic("should be empty")
|
||||
}
|
||||
for _, q := range ws.sq {
|
||||
if n := ws.streamWritableBytes(q); n > 0 {
|
||||
ws.canSend = append(ws.canSend, q)
|
||||
}
|
||||
}
|
||||
if len(ws.canSend) == 0 {
|
||||
return
|
||||
}
|
||||
defer ws.zeroCanSend()
|
||||
|
||||
// TODO: find the best queue
|
||||
q := ws.canSend[0]
|
||||
|
||||
return ws.takeFrom(q.streamID(), q)
|
||||
}
|
||||
|
||||
// zeroCanSend is defered from take.
|
||||
func (ws *writeScheduler) zeroCanSend() {
|
||||
for i := range ws.canSend {
|
||||
ws.canSend[i] = nil
|
||||
}
|
||||
ws.canSend = ws.canSend[:0]
|
||||
}
|
||||
|
||||
// streamWritableBytes returns the number of DATA bytes we could write
|
||||
// from the given queue's stream, if this stream/queue were
|
||||
// selected. It is an error to call this if q's head isn't a
|
||||
// *writeData.
|
||||
func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 {
|
||||
wm := q.head()
|
||||
ret := wm.stream.flow.available() // max we can write
|
||||
if ret == 0 {
|
||||
return 0
|
||||
}
|
||||
if int32(ws.maxFrameSize) < ret {
|
||||
ret = int32(ws.maxFrameSize)
|
||||
}
|
||||
if ret == 0 {
|
||||
panic("internal error: ws.maxFrameSize not initialized or invalid")
|
||||
}
|
||||
wd := wm.write.(*writeData)
|
||||
if len(wd.p) < int(ret) {
|
||||
ret = int32(len(wd.p))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) {
|
||||
wm = q.head()
|
||||
// If the first item in this queue costs flow control tokens
|
||||
// and we don't have enough, write as much as we can.
|
||||
if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 {
|
||||
allowed := wm.stream.flow.available() // max we can write
|
||||
if allowed == 0 {
|
||||
// No quota available. Caller can try the next stream.
|
||||
return frameWriteMsg{}, false
|
||||
}
|
||||
if int32(ws.maxFrameSize) < allowed {
|
||||
allowed = int32(ws.maxFrameSize)
|
||||
}
|
||||
// TODO: further restrict the allowed size, because even if
|
||||
// the peer says it's okay to write 16MB data frames, we might
|
||||
// want to write smaller ones to properly weight competing
|
||||
// streams' priorities.
|
||||
|
||||
if len(wd.p) > int(allowed) {
|
||||
wm.stream.flow.take(allowed)
|
||||
chunk := wd.p[:allowed]
|
||||
wd.p = wd.p[allowed:]
|
||||
// Make up a new write message of a valid size, rather
|
||||
// than shifting one off the queue.
|
||||
return frameWriteMsg{
|
||||
stream: wm.stream,
|
||||
write: &writeData{
|
||||
streamID: wd.streamID,
|
||||
p: chunk,
|
||||
// even if the original had endStream set, there
|
||||
// arebytes remaining because len(wd.p) > allowed,
|
||||
// so we know endStream is false:
|
||||
endStream: false,
|
||||
},
|
||||
// our caller is blocking on the final DATA frame, not
|
||||
// these intermediates, so no need to wait:
|
||||
done: nil,
|
||||
}, true
|
||||
}
|
||||
wm.stream.flow.take(int32(len(wd.p)))
|
||||
}
|
||||
|
||||
q.shift()
|
||||
if q.empty() {
|
||||
ws.putEmptyQueue(q)
|
||||
delete(ws.sq, id)
|
||||
}
|
||||
return wm, true
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) forgetStream(id uint32) {
|
||||
q, ok := ws.sq[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ws.sq, id)
|
||||
|
||||
// But keep it for others later.
|
||||
for i := range q.s {
|
||||
q.s[i] = frameWriteMsg{}
|
||||
}
|
||||
q.s = q.s[:0]
|
||||
ws.putEmptyQueue(q)
|
||||
}
|
||||
|
||||
type writeQueue struct {
|
||||
s []frameWriteMsg
|
||||
}
|
||||
|
||||
// streamID returns the stream ID for a non-empty stream-specific queue.
|
||||
func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id }
|
||||
|
||||
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
|
||||
|
||||
func (q *writeQueue) push(wm frameWriteMsg) {
|
||||
q.s = append(q.s, wm)
|
||||
}
|
||||
|
||||
// head returns the next item that would be removed by shift.
|
||||
func (q *writeQueue) head() frameWriteMsg {
|
||||
if len(q.s) == 0 {
|
||||
panic("invalid use of queue")
|
||||
}
|
||||
return q.s[0]
|
||||
}
|
||||
|
||||
func (q *writeQueue) shift() frameWriteMsg {
|
||||
if len(q.s) == 0 {
|
||||
panic("invalid use of queue")
|
||||
}
|
||||
wm := q.s[0]
|
||||
// TODO: less copy-happy queue.
|
||||
copy(q.s, q.s[1:])
|
||||
q.s[len(q.s)-1] = frameWriteMsg{}
|
||||
q.s = q.s[:len(q.s)-1]
|
||||
return wm
|
||||
}
|
||||
|
||||
func (q *writeQueue) firstIsNoCost() bool {
|
||||
if df, ok := q.s[0].write.(*writeData); ok {
|
||||
return len(df.p) == 0
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,312 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package httplex contains rules around lexical matters of various
|
||||
// HTTP-related specifications.
|
||||
//
|
||||
// This package is shared by the standard library (which vendors it)
|
||||
// and x/net/http2. It comes with no API stability promise.
|
||||
package httplex
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var isTokenTable = [127]bool{
|
||||
'!': true,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'W': true,
|
||||
'V': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'|': true,
|
||||
'~': true,
|
||||
}
|
||||
|
||||
func IsTokenRune(r rune) bool {
|
||||
i := int(r)
|
||||
return i < len(isTokenTable) && isTokenTable[i]
|
||||
}
|
||||
|
||||
func isNotToken(r rune) bool {
|
||||
return !IsTokenRune(r)
|
||||
}
|
||||
|
||||
// HeaderValuesContainsToken reports whether any string in values
|
||||
// contains the provided token, ASCII case-insensitively.
|
||||
func HeaderValuesContainsToken(values []string, token string) bool {
|
||||
for _, v := range values {
|
||||
if headerValueContainsToken(v, token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isOWS reports whether b is an optional whitespace byte, as defined
|
||||
// by RFC 7230 section 3.2.3.
|
||||
func isOWS(b byte) bool { return b == ' ' || b == '\t' }
|
||||
|
||||
// trimOWS returns x with all optional whitespace removes from the
|
||||
// beginning and end.
|
||||
func trimOWS(x string) string {
|
||||
// TODO: consider using strings.Trim(x, " \t") instead,
|
||||
// if and when it's fast enough. See issue 10292.
|
||||
// But this ASCII-only code will probably always beat UTF-8
|
||||
// aware code.
|
||||
for len(x) > 0 && isOWS(x[0]) {
|
||||
x = x[1:]
|
||||
}
|
||||
for len(x) > 0 && isOWS(x[len(x)-1]) {
|
||||
x = x[:len(x)-1]
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// headerValueContainsToken reports whether v (assumed to be a
|
||||
// 0#element, in the ABNF extension described in RFC 7230 section 7)
|
||||
// contains token amongst its comma-separated tokens, ASCII
|
||||
// case-insensitively.
|
||||
func headerValueContainsToken(v string, token string) bool {
|
||||
v = trimOWS(v)
|
||||
if comma := strings.IndexByte(v, ','); comma != -1 {
|
||||
return tokenEqual(trimOWS(v[:comma]), token) || headerValueContainsToken(v[comma+1:], token)
|
||||
}
|
||||
return tokenEqual(v, token)
|
||||
}
|
||||
|
||||
// lowerASCII returns the ASCII lowercase version of b.
|
||||
func lowerASCII(b byte) byte {
|
||||
if 'A' <= b && b <= 'Z' {
|
||||
return b + ('a' - 'A')
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively.
|
||||
func tokenEqual(t1, t2 string) bool {
|
||||
if len(t1) != len(t2) {
|
||||
return false
|
||||
}
|
||||
for i, b := range t1 {
|
||||
if b >= utf8.RuneSelf {
|
||||
// No UTF-8 or non-ASCII allowed in tokens.
|
||||
return false
|
||||
}
|
||||
if lowerASCII(byte(b)) != lowerASCII(t2[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isLWS reports whether b is linear white space, according
|
||||
// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
|
||||
// LWS = [CRLF] 1*( SP | HT )
|
||||
func isLWS(b byte) bool { return b == ' ' || b == '\t' }
|
||||
|
||||
// isCTL reports whether b is a control byte, according
|
||||
// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
|
||||
// CTL = <any US-ASCII control character
|
||||
// (octets 0 - 31) and DEL (127)>
|
||||
func isCTL(b byte) bool {
|
||||
const del = 0x7f // a CTL
|
||||
return b < ' ' || b == del
|
||||
}
|
||||
|
||||
// ValidHeaderFieldName reports whether v is a valid HTTP/1.x header name.
|
||||
// HTTP/2 imposes the additional restriction that uppercase ASCII
|
||||
// letters are not allowed.
|
||||
//
|
||||
// RFC 7230 says:
|
||||
// header-field = field-name ":" OWS field-value OWS
|
||||
// field-name = token
|
||||
// token = 1*tchar
|
||||
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
|
||||
// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
|
||||
func ValidHeaderFieldName(v string) bool {
|
||||
if len(v) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range v {
|
||||
if !IsTokenRune(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidHostHeader reports whether h is a valid host header.
|
||||
func ValidHostHeader(h string) bool {
|
||||
// The latest spec is actually this:
|
||||
//
|
||||
// http://tools.ietf.org/html/rfc7230#section-5.4
|
||||
// Host = uri-host [ ":" port ]
|
||||
//
|
||||
// Where uri-host is:
|
||||
// http://tools.ietf.org/html/rfc3986#section-3.2.2
|
||||
//
|
||||
// But we're going to be much more lenient for now and just
|
||||
// search for any byte that's not a valid byte in any of those
|
||||
// expressions.
|
||||
for i := 0; i < len(h); i++ {
|
||||
if !validHostByte[h[i]] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// See the validHostHeader comment.
|
||||
var validHostByte = [256]bool{
|
||||
'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
|
||||
'8': true, '9': true,
|
||||
|
||||
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
|
||||
'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
|
||||
'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
|
||||
'y': true, 'z': true,
|
||||
|
||||
'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
|
||||
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
|
||||
'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
|
||||
'Y': true, 'Z': true,
|
||||
|
||||
'!': true, // sub-delims
|
||||
'$': true, // sub-delims
|
||||
'%': true, // pct-encoded (and used in IPv6 zones)
|
||||
'&': true, // sub-delims
|
||||
'(': true, // sub-delims
|
||||
')': true, // sub-delims
|
||||
'*': true, // sub-delims
|
||||
'+': true, // sub-delims
|
||||
',': true, // sub-delims
|
||||
'-': true, // unreserved
|
||||
'.': true, // unreserved
|
||||
':': true, // IPv6address + Host expression's optional port
|
||||
';': true, // sub-delims
|
||||
'=': true, // sub-delims
|
||||
'[': true,
|
||||
'\'': true, // sub-delims
|
||||
']': true,
|
||||
'_': true, // unreserved
|
||||
'~': true, // unreserved
|
||||
}
|
||||
|
||||
// ValidHeaderFieldValue reports whether v is a valid "field-value" according to
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 :
|
||||
//
|
||||
// message-header = field-name ":" [ field-value ]
|
||||
// field-value = *( field-content | LWS )
|
||||
// field-content = <the OCTETs making up the field-value
|
||||
// and consisting of either *TEXT or combinations
|
||||
// of token, separators, and quoted-string>
|
||||
//
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 :
|
||||
//
|
||||
// TEXT = <any OCTET except CTLs,
|
||||
// but including LWS>
|
||||
// LWS = [CRLF] 1*( SP | HT )
|
||||
// CTL = <any US-ASCII control character
|
||||
// (octets 0 - 31) and DEL (127)>
|
||||
//
|
||||
// RFC 7230 says:
|
||||
// field-value = *( field-content / obs-fold )
|
||||
// obj-fold = N/A to http2, and deprecated
|
||||
// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
|
||||
// field-vchar = VCHAR / obs-text
|
||||
// obs-text = %x80-FF
|
||||
// VCHAR = "any visible [USASCII] character"
|
||||
//
|
||||
// http2 further says: "Similarly, HTTP/2 allows header field values
|
||||
// that are not valid. While most of the values that can be encoded
|
||||
// will not alter header field parsing, carriage return (CR, ASCII
|
||||
// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII
|
||||
// 0x0) might be exploited by an attacker if they are translated
|
||||
// verbatim. Any request or response that contains a character not
|
||||
// permitted in a header field value MUST be treated as malformed
|
||||
// (Section 8.1.2.6). Valid characters are defined by the
|
||||
// field-content ABNF rule in Section 3.2 of [RFC7230]."
|
||||
//
|
||||
// This function does not (yet?) properly handle the rejection of
|
||||
// strings that begin or end with SP or HTAB.
|
||||
func ValidHeaderFieldValue(v string) bool {
|
||||
for i := 0; i < len(v); i++ {
|
||||
b := v[i]
|
||||
if isCTL(b) && !isLWS(b) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -834,6 +834,24 @@
|
|||
"revision": "9f2c271364b418388d150f9c12ecbf12794095c1",
|
||||
"revisionTime": "2016-07-26T08:08:57Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "g4lA8FNetHh8as26V8cLl5EGy78=",
|
||||
"path": "golang.org/x/net/http2",
|
||||
"revision": "07b51741c1d6423d4a6abab1c49940ec09cb1aaf",
|
||||
"revisionTime": "2016-08-11T10:50:59Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "EYNaHp7XdLWRydUCE0amEkKAtgk=",
|
||||
"path": "golang.org/x/net/http2/hpack",
|
||||
"revision": "07b51741c1d6423d4a6abab1c49940ec09cb1aaf",
|
||||
"revisionTime": "2016-08-11T10:50:59Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "yhndhWXMs/VSEDLks4dNyFMQStA=",
|
||||
"path": "golang.org/x/net/lex/httplex",
|
||||
"revision": "07b51741c1d6423d4a6abab1c49940ec09cb1aaf",
|
||||
"revisionTime": "2016-08-11T10:50:59Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "8GUYZXntrAzpG1cThk8/TtR3kj4=",
|
||||
"path": "golang.org/x/oauth2",
|
||||
|
|
|
@ -27,10 +27,7 @@ The following table describes them:
|
|||
<td><tt>VAULT_ADDR</tt></td>
|
||||
<td>The address of the Vault server.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_ADVERTISE_ADDR</tt></td>
|
||||
<td>The advertised address of the server to use for client request forwarding when running in High Availability mode.</td>
|
||||
<tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_CACERT</tt></td>
|
||||
<td>Path to a PEM-encoded CA cert file to use to verify the Vault server SSL certificate.</td>
|
||||
</tr>
|
||||
|
@ -46,10 +43,18 @@ The following table describes them:
|
|||
<td><tt>VAULT_CLIENT_KEY</tt></td>
|
||||
<td>Path to an unencrypted PEM-encoded private key matching the client certificate.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_CLUSTER_ADDR</tt></td>
|
||||
<td>The address that should be used for other cluster members to connect to this node when in High Availability mode.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_MAX_RETRIES</tt></td>
|
||||
<td>The maximum number of retries when a `5xx` error code is encountered. Default is `2`, for three total tries; set to `0` or less to disable retrying.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_REDIRECT_ADDR</tt></td>
|
||||
<td>The address that should be used when clients are redirected to this node when in High Availability mode.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><tt>VAULT_SKIP_VERIFY</tt></td>
|
||||
<td>If set, do not verify Vault's presented certificate before communicating with it. Setting this variable is not recommended except during testing.</td>
|
||||
|
|
|
@ -8,25 +8,138 @@ description: |-
|
|||
|
||||
# High Availability Mode (HA)
|
||||
|
||||
Vault supports multi-server mode for high availability. This mode protects
|
||||
Vault supports a multi-server mode for high availability. This mode protects
|
||||
against outages by running multiple Vault servers. High availability mode
|
||||
is automatically enabled when using a storage backend that supports it.
|
||||
is automatically enabled when using a data store that supports it.
|
||||
|
||||
You can tell if a backend supports high availability mode ("HA") by
|
||||
starting the server and seeing if "(HA available)" is outputted next to
|
||||
the backend information. If it is, then HA will begin happening automatically.
|
||||
You can tell if a data store supports high availability mode ("HA") by starting
|
||||
the server and seeing if "(HA available)" is output next to the data store
|
||||
information. If it is, then Vault will automatically use HA mode. This
|
||||
information is also available on the
|
||||
[Configuration](https://www.vaultproject.io/docs/config/index.html) page.
|
||||
|
||||
To be highly available, Vault elects a leader and does request forwarding to
|
||||
the leader. Due to this architecture, HA does not enable increased scalability.
|
||||
In general, the bottleneck of Vault is the storage backend itself, not
|
||||
Vault core. For example: to increase scalability of Vault with Consul, you
|
||||
would scale Consul instead of Vault.
|
||||
To be highly available, one of the Vault server nodes grabs a lock within the
|
||||
data store. The successful server node then becomes the active node; all other
|
||||
nodes become standby nodes. At this point, if the standby nodes receive a
|
||||
request, they will either forward the request or redirect the client depending
|
||||
on the current configuration and state of the cluster -- see the sections below
|
||||
for details. Due to this architecture, HA does not enable increased
|
||||
scalability. In general, the bottleneck of Vault is the data store itself, not
|
||||
Vault core. For example: to increase the scalability of Vault with Consul, you
|
||||
would generally scale Consul instead of Vault.
|
||||
|
||||
In addition to using a backend that supports HA, you have to configure
|
||||
Vault with an _advertise address_. This is the address that Vault advertises
|
||||
to other Vault servers in the cluster for request forwarding. By default,
|
||||
Vault will use the first private IP address it finds, but you can override
|
||||
this to any address you want.
|
||||
The sections below explain the server communication pattens and each type of
|
||||
request handling in more detail. At a minimum, the requirements for redirection
|
||||
mode must be met for an HA cluster to work successfully.
|
||||
|
||||
## Server-to-Server Communication
|
||||
|
||||
Both methods of request handling rely on the active node advertising
|
||||
information about itself to the other nodes. Rather than over the network, this
|
||||
communication takes place within Vault's encrypted storage; the active node
|
||||
writes this information and unsealed standby Vault nodes can read it.
|
||||
|
||||
For the client redirection method, this is the extent of server-to-server
|
||||
communication -- no direct communication with only encrypted entries in the
|
||||
data store used to transfer state.
|
||||
|
||||
For the request forwarding method, the servers need direct communication with
|
||||
each other. In order to perform this securely, the active node also advertises,
|
||||
via the encrypted data store entry, a newly-generated private key (ECDSA-P521)
|
||||
and a newly-generated self-signed certificate designated for client and server
|
||||
authentication. Each standby uses the private key and certificate to open a
|
||||
mutually-authenticated TLS 1.2 connection to the active node via the advertised
|
||||
cluster address. When client requests come in, the requests are serialized,
|
||||
sent over this TLS-protected communication channel, and acted upon by the
|
||||
active node. The active node then returns a response to the standby, which
|
||||
sends the response back to the requesting client.
|
||||
|
||||
## Client Redirection
|
||||
|
||||
This is currently the only mode enabled by default. When a standby node
|
||||
receives a request, it will redirect the client using a `307` status code to
|
||||
the _active node's_ redirect address.
|
||||
|
||||
This is also the fallback method used when request forwarding is turned off or
|
||||
there is an error performing the forwarding. As such, a redirect address is
|
||||
always required for all HA setups.
|
||||
|
||||
Some HA data store drivers can autodetect the redirect address, but it is often
|
||||
necessary to configure it manually via setting a value in the `backend`
|
||||
configuration block (or `ha_backend` if using split data/HA mode). The key for
|
||||
this value is `redirect_addr` and the value can also be specified by the
|
||||
`VAULT_REDIRECT_ADDR` environment variable, which takes precedence.
|
||||
|
||||
What the `redirect_addr` value should be set to depends on how Vault is set up.
|
||||
There are two common scenarios: Vault servers accessed directly by clients, and
|
||||
Vault servers accessed via a load balancer.
|
||||
|
||||
In both cases, the `redirect_addr` should be a full URL including scheme
|
||||
(`http`/`https`), not simply an IP address and port.
|
||||
|
||||
### Direct Access
|
||||
|
||||
When clients are able to access Vault directly, the `redirect_addr` for each
|
||||
node should be that node's address. For instance, if there are two Vault nodes
|
||||
`A` (accessed via `https://a.vault.mycompany.com`) and `B` (accessed via
|
||||
`https://b.vault.mycompany.com`), node `A` would set its `redirect_addr` to
|
||||
`https://a.vault.mycompany.com` and node `B` would set its `redirect_addr` to
|
||||
`https://b.vault.mycompany.com`.
|
||||
|
||||
This way, when `A` is the active node, any requests received by node `B` will
|
||||
cause it to redirect the client to node `A`'s `redirect_addr` at
|
||||
`https://a.vault.mycompany.com`, and vice-versa.
|
||||
|
||||
### Behind Load Balancers
|
||||
|
||||
Sometimes clients use load balancers as an initial method to access one of the
|
||||
Vault servers, but actually have direct access to each Vault node. In this
|
||||
case, the Vault servers should actually be set up as described in the above
|
||||
section, since for redirection purposes the clients have direct access.
|
||||
|
||||
However, if the only access to the Vault servers is via the load balancer, the
|
||||
`redirect_addr` on each node should be the same: the address of the load
|
||||
balancer. Clients that reach a standby node will be redirected back to the load
|
||||
balancer; at that point hopefully the load balancer's configuration will have
|
||||
been updated to know the address of the current leader. This can cause a
|
||||
redirect loop and as such is not a recommended setup when it can be avoided.
|
||||
|
||||
## Request Forwarding
|
||||
|
||||
Request forwarding is in beta in 0.6.1 and disabled by default; in a future
|
||||
release, it will be enabled by default. To enable request forwarding on a 0.6.1
|
||||
server, set the value of the key `disable_clustering` to `"false"` (note the
|
||||
quotes) in the `backend` block (or `ha_backend` block if using split data/HA
|
||||
backends).
|
||||
|
||||
If request forwarding is enabled, clients can still force the older/fallback
|
||||
redirection behavior if desired by setting the `X-Vault-No-Request-Forwarding`
|
||||
header to any non-empty value.
|
||||
|
||||
Successful cluster setup requires a few configuration parameters, although some
|
||||
can be automatically determined.
|
||||
|
||||
### Per-Node Cluster Listener Addresses
|
||||
|
||||
Each `listener` block in Vault's configuration file contains an `address` value
|
||||
on which Vault listens for requests. Similarly, each `listener` block can
|
||||
contain a `cluster_address` on which Vault listens for server-to-server cluster
|
||||
requests. If this value is not set, its IP address will be automatically set to
|
||||
same as the `address` value, and its port will be automatically set to the same
|
||||
as the `address` value plus one (so by default, port `8201`).
|
||||
|
||||
### Per-Node Cluster Address
|
||||
|
||||
Similar to the `redirect_addr`, this is the value that each node, if active,
|
||||
should advertise to the standbys to use for server-to-server communications,
|
||||
and lives in the `backend` (or `ha_backend`) block. On each node, this should
|
||||
be set to a host name or IP address that a standby can use to reach one of that
|
||||
node's `cluster_address` values set in the `listener` blocks, including port.
|
||||
(Note that this will always be forced to `https` since only TLS connections are
|
||||
used between servers.)
|
||||
|
||||
This value can also be specified by the `VAULT_CLUSTER_ADDR` environment
|
||||
variable, which takes precedence.
|
||||
|
||||
## Backend Support
|
||||
|
||||
|
@ -37,6 +150,6 @@ including Consul, ZooKeeper and etcd. These may change over time, and the
|
|||
The Consul backend is the recommended HA backend, as it is used in production
|
||||
by HashiCorp and its customers with commercial support.
|
||||
|
||||
If you're interested in implementing another backend or adding HA support
|
||||
to another backend, we'd love your contributions. Adding HA support
|
||||
requires implementing the `physical.HABackend` interface for the storage backend.
|
||||
If you're interested in implementing another backend or adding HA support to
|
||||
another backend, we'd love your contributions. Adding HA support requires
|
||||
implementing the `physical.HABackend` interface for the storage backend.
|
||||
|
|
|
@ -96,6 +96,11 @@ The supported options are:
|
|||
* `address` (optional) - The address to bind to for listening. This
|
||||
defaults to "127.0.0.1:8200".
|
||||
|
||||
* `cluster_address` (optional) - The address to bind to for cluster
|
||||
server-to-server requests. This defaults to one port higher than the
|
||||
value of `address`, so with the default value of `address`, this would be
|
||||
"127.0.0.1:8201".
|
||||
|
||||
* `tls_disable` (optional) - If true, then TLS will be disabled.
|
||||
This will parse as boolean value, and can be set to "0", "no",
|
||||
"false", "1", "yes", or "true". This is an opt-in; Vault assumes
|
||||
|
@ -213,19 +218,26 @@ to help you, but may refer you to the backend author.
|
|||
This backend does not support HA.
|
||||
|
||||
|
||||
#### Common Backend Options
|
||||
#### High Availability Options
|
||||
|
||||
All backends support the following options:
|
||||
All HA backends support the following options. These are discussed in much more
|
||||
detail in the [High Availability concepts
|
||||
page](https://www.vaultproject.io/docs/concepts/ha.html).
|
||||
|
||||
* `advertise_addr` (optional) - For backends that support HA, this
|
||||
is the address to advertise to other Vault servers in the cluster for
|
||||
request forwarding. As an example, if a cluster contains nodes A, B, and C,
|
||||
node A should set it to the address that B and C should redirect client
|
||||
nodes to when A is the active node and B and C are standby nodes. This may
|
||||
be the same address across nodes if using a load balancer or service
|
||||
discovery. Most HA backends will attempt to determine the advertise address
|
||||
if not provided. This can also be overridden via the `VAULT_ADVERTISE_ADDR`
|
||||
environment variable.
|
||||
* `redirect_addr` (optional) - This is the address to advertise to other
|
||||
Vault servers in the cluster for client redirection. This can also be
|
||||
set via the `VAULT_REDIRECT_ADDR` environment variable, which takes
|
||||
precedence.
|
||||
|
||||
* `cluster_addr` (optional) - This is the address to advertise to other Vault
|
||||
servers in the cluster for request forwarding. This can also be set via the
|
||||
`VAULT_CLUSTER_ADDR` environment variable, which takes precedence.
|
||||
|
||||
* `disable_clustering` (optional) - This controls whether clustering features
|
||||
(currently, request forwarding) are enabled. Setting this on a node will
|
||||
disable these features _when that node is the active node_. In 0.6.1 this
|
||||
is `"true"` (note the quotes) by default, but will become `"false"` by
|
||||
default in the next release.
|
||||
|
||||
#### Backend Reference: Consul
|
||||
|
||||
|
@ -248,7 +260,7 @@ For Consul, the following options are supported:
|
|||
* `service` (optional) - The name of the service to register with Consul.
|
||||
Defaults to "vault".
|
||||
|
||||
* `service-tags` (optional) - Comma separated list of tags that are to be
|
||||
* `service_tags` (optional) - Comma separated list of tags that are to be
|
||||
applied to the service that gets registered with Consul.
|
||||
|
||||
* `token` (optional) - An access token to use to write data to Consul.
|
||||
|
|
|
@ -11,6 +11,13 @@ description: |-
|
|||
This page contains the list of breaking changes for Vault 0.6.1. Please read it
|
||||
carefully.
|
||||
|
||||
## Standby Nodes Must Be 0.6.1 As Well
|
||||
|
||||
Once an active node is running 0.6.1, only standby nodes running 0.6.1+ will be
|
||||
able to form an HA cluster. If following our [general upgrade
|
||||
instructions](https://www.vaultproject.io/docs/install/upgrade.html) this will
|
||||
not be an issue.
|
||||
|
||||
## Root Token Creation Restrictions
|
||||
|
||||
Root tokens (tokens with the `root` policy) can no longer be created except by
|
||||
|
|
|
@ -25,9 +25,16 @@ upgrade notes.
|
|||
|
||||
## HA Installations
|
||||
|
||||
This is our recommended upgrade procedure, and the procedure we use internally at HashiCorp. However, you should consider how to apply these steps to your particular setup since HA setups can differ on whether a load balancer is in use, what addresses clients are being given to connect to Vault (standby + leader, leader-only, or discovered via service discovery), etc.
|
||||
This is our recommended upgrade procedure, and the procedure we use internally
|
||||
at HashiCorp. However, you should consider how to apply these steps to your
|
||||
particular setup since HA setups can differ on whether a load balancer is in
|
||||
use, what addresses clients are being given to connect to Vault (standby +
|
||||
leader, leader-only, or discovered via service discovery), etc.
|
||||
|
||||
Please note that Vault does not support true zero-downtime upgrades, but with proper upgrade procedure the downtime should be very short (a few hundred milliseconds to a second depending on how the speed of access to the storage backend).
|
||||
Please note that Vault does not support true zero-downtime upgrades, but with
|
||||
proper upgrade procedure the downtime should be very short (a few hundred
|
||||
milliseconds to a second depending on how the speed of access to the storage
|
||||
backend).
|
||||
|
||||
Perform these steps on each standby:
|
||||
|
||||
|
@ -36,7 +43,9 @@ Perform these steps on each standby:
|
|||
3. Start the standby node
|
||||
4. Unseal the standby node
|
||||
|
||||
At this point all standby nodes will be upgraded and ready to take over. The upgrade will not be complete until one of the upgraded standby nodes takes over active duty. To do this:
|
||||
At this point all standby nodes will be upgraded and ready to take over. The
|
||||
upgrade will not be complete until one of the upgraded standby nodes takes over
|
||||
active duty. To do this:
|
||||
|
||||
1. Properly shut down the remaining (active) node. Note: it is _**very
|
||||
important**_ that you shut the node down properly. This causes the HA lock to
|
||||
|
|
Loading…
Reference in New Issue