Merge pull request #4156 from hashicorp/enterprise-coexistence

Enterprise/Licensing Cleanup
This commit is contained in:
Matt Keeler 2018-06-05 10:50:32 -04:00 committed by GitHub
commit e043621dd3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 189 additions and 36 deletions

1
.gitignore vendored
View file

@ -5,6 +5,7 @@
*.swp *.swp
*.test *.test
.DS_Store .DS_Store
.fseventsd
.vagrant/ .vagrant/
/pkg /pkg
Thumbs.db Thumbs.db

View file

@ -73,6 +73,7 @@ type delegate interface {
SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error
Shutdown() error Shutdown() error
Stats() map[string]map[string]string Stats() map[string]map[string]string
enterpriseDelegate
} }
// notifier is called after a successful JoinLAN. // notifier is called after a successful JoinLAN.

View file

@ -72,6 +72,9 @@ type Client struct {
shutdown bool shutdown bool
shutdownCh chan struct{} shutdownCh chan struct{}
shutdownLock sync.Mutex shutdownLock sync.Mutex
// embedded struct to hold all the enterprise specific data
EnterpriseClient
} }
// NewClient is used to construct a new Consul client from the // NewClient is used to construct a new Consul client from the
@ -131,6 +134,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
} }
if err := c.initEnterprise(); err != nil {
c.Shutdown()
return nil, err
}
// Initialize the LAN Serf // Initialize the LAN Serf
c.serf, err = c.setupSerf(config.SerfLANConfig, c.serf, err = c.setupSerf(config.SerfLANConfig,
c.eventCh, serfLANSnapshot) c.eventCh, serfLANSnapshot)
@ -147,6 +155,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
// handlers depend on the router and the router depends on Serf. // handlers depend on the router and the router depends on Serf.
go c.lanEventHandler() go c.lanEventHandler()
if err := c.startEnterprise(); err != nil {
c.Shutdown()
return nil, err
}
return c, nil return c, nil
} }
@ -342,6 +355,17 @@ func (c *Client) Stats() map[string]map[string]string {
"serf_lan": c.serf.Stats(), "serf_lan": c.serf.Stats(),
"runtime": runtimeStats(), "runtime": runtimeStats(),
} }
for outerKey, outerValue := range c.enterpriseStats() {
if _, ok := stats[outerKey]; ok {
for innerKey, innerValue := range outerValue {
stats[outerKey][innerKey] = innerValue
}
} else {
stats[outerKey] = outerValue
}
}
return stats return stats
} }

View file

@ -135,6 +135,8 @@ func (c *Client) localEvent(event serf.UserEvent) {
c.config.UserEventHandler(event) c.config.UserEventHandler(event)
} }
default: default:
c.logger.Printf("[WARN] consul: Unhandled local event: %v", event) if !c.handleEnterpriseUserEvents(event) {
c.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
}
} }
} }

View file

@ -0,0 +1,25 @@
// +build !ent
package consul
import (
"github.com/hashicorp/serf/serf"
)
type EnterpriseClient struct{}
func (c *Client) initEnterprise() error {
return nil
}
func (c *Client) startEnterprise() error {
return nil
}
func (c *Client) handleEnterpriseUserEvents(event serf.UserEvent) bool {
return false
}
func (c *Client) enterpriseStats() map[string]map[string]string {
return nil
}

View file

@ -0,0 +1,32 @@
// +build !ent
package consul
import (
"net"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/serf/serf"
)
type EnterpriseServer struct{}
func (s *Server) initEnterprise() error {
return nil
}
func (s *Server) startEnterprise() error {
return nil
}
func (s *Server) handleEnterpriseUserEvents(event serf.UserEvent) bool {
return false
}
func (s *Server) handleEnterpriseRPCConn(rtype pool.RPCType, conn net.Conn, isTLS bool) bool {
return false
}
func (s *Server) enterpriseStats() map[string]map[string]string {
return nil
}

View file

@ -115,9 +115,10 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
s.handleSnapshotConn(conn) s.handleSnapshotConn(conn)
default: default:
s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn)) if !s.handleEnterpriseRPCConn(typ, conn, isTLS) {
conn.Close() s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn))
return conn.Close()
}
} }
} }

View file

@ -208,6 +208,9 @@ type Server struct {
shutdown bool shutdown bool
shutdownCh chan struct{} shutdownCh chan struct{}
shutdownLock sync.Mutex shutdownLock sync.Mutex
// embedded struct to hold all the enterprise specific data
EnterpriseServer
} }
func NewServer(config *Config) (*Server, error) { func NewServer(config *Config) (*Server, error) {
@ -297,6 +300,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (*
shutdownCh: shutdownCh, shutdownCh: shutdownCh,
} }
// Initialize enterprise specific server functionality
if err := s.initEnterprise(); err != nil {
s.Shutdown()
return nil, err
}
// Initialize the stats fetcher that autopilot will use. // Initialize the stats fetcher that autopilot will use.
s.statsFetcher = NewStatsFetcher(logger, s.connPool, s.config.Datacenter) s.statsFetcher = NewStatsFetcher(logger, s.connPool, s.config.Datacenter)
@ -338,6 +347,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (*
return nil, fmt.Errorf("Failed to start Raft: %v", err) return nil, fmt.Errorf("Failed to start Raft: %v", err)
} }
// Start enterprise specific functionality
if err := s.startEnterprise(); err != nil {
s.Shutdown()
return nil, err
}
// Serf and dynamic bind ports // Serf and dynamic bind ports
// //
// The LAN serf cluster announces the port of the WAN serf cluster // The LAN serf cluster announces the port of the WAN serf cluster
@ -1019,6 +1034,17 @@ func (s *Server) Stats() map[string]map[string]string {
if s.serfWAN != nil { if s.serfWAN != nil {
stats["serf_wan"] = s.serfWAN.Stats() stats["serf_wan"] = s.serfWAN.Stats()
} }
for outerKey, outerValue := range s.enterpriseStats() {
if _, ok := stats[outerKey]; ok {
for innerKey, innerValue := range outerValue {
stats[outerKey][innerKey] = innerValue
}
} else {
stats[outerKey] = outerValue
}
}
return stats return stats
} }

View file

@ -198,7 +198,9 @@ func (s *Server) localEvent(event serf.UserEvent) {
s.config.UserEventHandler(event) s.config.UserEventHandler(event)
} }
default: default:
s.logger.Printf("[WARN] consul: Unhandled local event: %v", event) if !s.handleEnterpriseUserEvents(event) {
s.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
}
} }
} }

View file

@ -0,0 +1,6 @@
// +build !ent
package agent
// enterpriseDelegate has no functions in OSS
type enterpriseDelegate interface{}

View file

@ -31,6 +31,15 @@ func (e MethodNotAllowedError) Error() string {
return fmt.Sprintf("method %s not allowed", e.Method) return fmt.Sprintf("method %s not allowed", e.Method)
} }
// BadRequestError should be returned by a handler when parameters or the payload are not valid
type BadRequestError struct {
Reason string
}
func (e BadRequestError) Error() string {
return fmt.Sprintf("Bad request: %s", e.Reason)
}
// HTTPServer provides an HTTP api for an agent. // HTTPServer provides an HTTP api for an agent.
type HTTPServer struct { type HTTPServer struct {
*http.Server *http.Server
@ -249,6 +258,11 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc {
return ok return ok
} }
isBadRequest := func(err error) bool {
_, ok := err.(BadRequestError)
return ok
}
addAllowHeader := func(methods []string) { addAllowHeader := func(methods []string) {
resp.Header().Add("Allow", strings.Join(methods, ",")) resp.Header().Add("Allow", strings.Join(methods, ","))
} }
@ -269,6 +283,9 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc {
addAllowHeader(err.(MethodNotAllowedError).Allow) addAllowHeader(err.(MethodNotAllowedError).Allow)
resp.WriteHeader(http.StatusMethodNotAllowed) // 405 resp.WriteHeader(http.StatusMethodNotAllowed) // 405
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())
case isBadRequest(err):
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
default: default:
resp.WriteHeader(http.StatusInternalServerError) resp.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(resp, err.Error()) fmt.Fprint(resp, err.Error())

View file

@ -0,0 +1,42 @@
package helpers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
)
func LoadDataSource(data string, testStdin io.Reader) (string, error) {
var stdin io.Reader = os.Stdin
if testStdin != nil {
stdin = testStdin
}
// Handle empty quoted shell parameters
if len(data) == 0 {
return "", nil
}
switch data[0] {
case '@':
data, err := ioutil.ReadFile(data[1:])
if err != nil {
return "", fmt.Errorf("Failed to read file: %s", err)
} else {
return string(data), nil
}
case '-':
if len(data) > 1 {
return data, nil
}
var b bytes.Buffer
if _, err := io.Copy(&b, stdin); err != nil {
return "", fmt.Errorf("Failed to read stdin: %s", err)
}
return b.String(), nil
default:
return data, nil
}
}

View file

@ -1,16 +1,14 @@
package put package put
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/command/flags" "github.com/hashicorp/consul/command/flags"
"github.com/hashicorp/consul/command/helpers"
"github.com/mitchellh/cli" "github.com/mitchellh/cli"
) )
@ -173,11 +171,6 @@ func (c *cmd) Run(args []string) int {
} }
func (c *cmd) dataFromArgs(args []string) (string, string, error) { func (c *cmd) dataFromArgs(args []string) (string, string, error) {
var stdin io.Reader = os.Stdin
if c.testStdin != nil {
stdin = c.testStdin
}
switch len(args) { switch len(args) {
case 0: case 0:
return "", "", fmt.Errorf("Missing KEY argument") return "", "", fmt.Errorf("Missing KEY argument")
@ -189,30 +182,11 @@ func (c *cmd) dataFromArgs(args []string) (string, string, error) {
} }
key := args[0] key := args[0]
data := args[1] data, err := helpers.LoadDataSource(args[1], c.testStdin)
// Handle empty quoted shell parameters if err != nil {
if len(data) == 0 { return "", "", err
return key, "", nil } else {
}
switch data[0] {
case '@':
data, err := ioutil.ReadFile(data[1:])
if err != nil {
return "", "", fmt.Errorf("Failed to read file: %s", err)
}
return key, string(data), nil
case '-':
if len(data) > 1 {
return key, data, nil
}
var b bytes.Buffer
if _, err := io.Copy(&b, stdin); err != nil {
return "", "", fmt.Errorf("Failed to read stdin: %s", err)
}
return key, b.String(), nil
default:
return key, data, nil return key, data, nil
} }
} }