Merge pull request #4156 from hashicorp/enterprise-coexistence

Enterprise/Licensing Cleanup
This commit is contained in:
Matt Keeler 2018-06-05 10:50:32 -04:00 committed by GitHub
commit e043621dd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 189 additions and 36 deletions

1
.gitignore vendored
View File

@ -5,6 +5,7 @@
*.swp
*.test
.DS_Store
.fseventsd
.vagrant/
/pkg
Thumbs.db

View File

@ -73,6 +73,7 @@ type delegate interface {
SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io.Writer, replyFn structs.SnapshotReplyFn) error
Shutdown() error
Stats() map[string]map[string]string
enterpriseDelegate
}
// notifier is called after a successful JoinLAN.

View File

@ -72,6 +72,9 @@ type Client struct {
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
// embedded struct to hold all the enterprise specific data
EnterpriseClient
}
// NewClient is used to construct a new Consul client from the
@ -131,6 +134,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
shutdownCh: make(chan struct{}),
}
if err := c.initEnterprise(); err != nil {
c.Shutdown()
return nil, err
}
// Initialize the LAN Serf
c.serf, err = c.setupSerf(config.SerfLANConfig,
c.eventCh, serfLANSnapshot)
@ -147,6 +155,11 @@ func NewClientLogger(config *Config, logger *log.Logger) (*Client, error) {
// handlers depend on the router and the router depends on Serf.
go c.lanEventHandler()
if err := c.startEnterprise(); err != nil {
c.Shutdown()
return nil, err
}
return c, nil
}
@ -342,6 +355,17 @@ func (c *Client) Stats() map[string]map[string]string {
"serf_lan": c.serf.Stats(),
"runtime": runtimeStats(),
}
for outerKey, outerValue := range c.enterpriseStats() {
if _, ok := stats[outerKey]; ok {
for innerKey, innerValue := range outerValue {
stats[outerKey][innerKey] = innerValue
}
} else {
stats[outerKey] = outerValue
}
}
return stats
}

View File

@ -135,6 +135,8 @@ func (c *Client) localEvent(event serf.UserEvent) {
c.config.UserEventHandler(event)
}
default:
c.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
if !c.handleEnterpriseUserEvents(event) {
c.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
}
}
}

View File

@ -0,0 +1,25 @@
// +build !ent
package consul
import (
"github.com/hashicorp/serf/serf"
)
type EnterpriseClient struct{}
func (c *Client) initEnterprise() error {
return nil
}
func (c *Client) startEnterprise() error {
return nil
}
func (c *Client) handleEnterpriseUserEvents(event serf.UserEvent) bool {
return false
}
func (c *Client) enterpriseStats() map[string]map[string]string {
return nil
}

View File

@ -0,0 +1,32 @@
// +build !ent
package consul
import (
"net"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/serf/serf"
)
type EnterpriseServer struct{}
func (s *Server) initEnterprise() error {
return nil
}
func (s *Server) startEnterprise() error {
return nil
}
func (s *Server) handleEnterpriseUserEvents(event serf.UserEvent) bool {
return false
}
func (s *Server) handleEnterpriseRPCConn(rtype pool.RPCType, conn net.Conn, isTLS bool) bool {
return false
}
func (s *Server) enterpriseStats() map[string]map[string]string {
return nil
}

View File

@ -115,9 +115,10 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) {
s.handleSnapshotConn(conn)
default:
s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn))
conn.Close()
return
if !s.handleEnterpriseRPCConn(typ, conn, isTLS) {
s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn))
conn.Close()
}
}
}

View File

@ -208,6 +208,9 @@ type Server struct {
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
// embedded struct to hold all the enterprise specific data
EnterpriseServer
}
func NewServer(config *Config) (*Server, error) {
@ -297,6 +300,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (*
shutdownCh: shutdownCh,
}
// Initialize enterprise specific server functionality
if err := s.initEnterprise(); err != nil {
s.Shutdown()
return nil, err
}
// Initialize the stats fetcher that autopilot will use.
s.statsFetcher = NewStatsFetcher(logger, s.connPool, s.config.Datacenter)
@ -338,6 +347,12 @@ func NewServerLogger(config *Config, logger *log.Logger, tokens *token.Store) (*
return nil, fmt.Errorf("Failed to start Raft: %v", err)
}
// Start enterprise specific functionality
if err := s.startEnterprise(); err != nil {
s.Shutdown()
return nil, err
}
// Serf and dynamic bind ports
//
// The LAN serf cluster announces the port of the WAN serf cluster
@ -1019,6 +1034,17 @@ func (s *Server) Stats() map[string]map[string]string {
if s.serfWAN != nil {
stats["serf_wan"] = s.serfWAN.Stats()
}
for outerKey, outerValue := range s.enterpriseStats() {
if _, ok := stats[outerKey]; ok {
for innerKey, innerValue := range outerValue {
stats[outerKey][innerKey] = innerValue
}
} else {
stats[outerKey] = outerValue
}
}
return stats
}

View File

@ -198,7 +198,9 @@ func (s *Server) localEvent(event serf.UserEvent) {
s.config.UserEventHandler(event)
}
default:
s.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
if !s.handleEnterpriseUserEvents(event) {
s.logger.Printf("[WARN] consul: Unhandled local event: %v", event)
}
}
}

View File

@ -0,0 +1,6 @@
// +build !ent
package agent
// enterpriseDelegate has no functions in OSS
type enterpriseDelegate interface{}

View File

@ -31,6 +31,15 @@ func (e MethodNotAllowedError) Error() string {
return fmt.Sprintf("method %s not allowed", e.Method)
}
// BadRequestError should be returned by a handler when parameters or the payload are not valid
type BadRequestError struct {
Reason string
}
func (e BadRequestError) Error() string {
return fmt.Sprintf("Bad request: %s", e.Reason)
}
// HTTPServer provides an HTTP api for an agent.
type HTTPServer struct {
*http.Server
@ -249,6 +258,11 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc {
return ok
}
isBadRequest := func(err error) bool {
_, ok := err.(BadRequestError)
return ok
}
addAllowHeader := func(methods []string) {
resp.Header().Add("Allow", strings.Join(methods, ","))
}
@ -269,6 +283,9 @@ func (s *HTTPServer) wrap(handler endpoint, methods []string) http.HandlerFunc {
addAllowHeader(err.(MethodNotAllowedError).Allow)
resp.WriteHeader(http.StatusMethodNotAllowed) // 405
fmt.Fprint(resp, err.Error())
case isBadRequest(err):
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprint(resp, err.Error())
default:
resp.WriteHeader(http.StatusInternalServerError)
fmt.Fprint(resp, err.Error())

View File

@ -0,0 +1,42 @@
package helpers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
)
func LoadDataSource(data string, testStdin io.Reader) (string, error) {
var stdin io.Reader = os.Stdin
if testStdin != nil {
stdin = testStdin
}
// Handle empty quoted shell parameters
if len(data) == 0 {
return "", nil
}
switch data[0] {
case '@':
data, err := ioutil.ReadFile(data[1:])
if err != nil {
return "", fmt.Errorf("Failed to read file: %s", err)
} else {
return string(data), nil
}
case '-':
if len(data) > 1 {
return data, nil
}
var b bytes.Buffer
if _, err := io.Copy(&b, stdin); err != nil {
return "", fmt.Errorf("Failed to read stdin: %s", err)
}
return b.String(), nil
default:
return data, nil
}
}

View File

@ -1,16 +1,14 @@
package put
import (
"bytes"
"encoding/base64"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/command/flags"
"github.com/hashicorp/consul/command/helpers"
"github.com/mitchellh/cli"
)
@ -173,11 +171,6 @@ func (c *cmd) Run(args []string) int {
}
func (c *cmd) dataFromArgs(args []string) (string, string, error) {
var stdin io.Reader = os.Stdin
if c.testStdin != nil {
stdin = c.testStdin
}
switch len(args) {
case 0:
return "", "", fmt.Errorf("Missing KEY argument")
@ -189,30 +182,11 @@ func (c *cmd) dataFromArgs(args []string) (string, string, error) {
}
key := args[0]
data := args[1]
data, err := helpers.LoadDataSource(args[1], c.testStdin)
// Handle empty quoted shell parameters
if len(data) == 0 {
return key, "", nil
}
switch data[0] {
case '@':
data, err := ioutil.ReadFile(data[1:])
if err != nil {
return "", "", fmt.Errorf("Failed to read file: %s", err)
}
return key, string(data), nil
case '-':
if len(data) > 1 {
return key, data, nil
}
var b bytes.Buffer
if _, err := io.Copy(&b, stdin); err != nil {
return "", "", fmt.Errorf("Failed to read stdin: %s", err)
}
return key, b.String(), nil
default:
if err != nil {
return "", "", err
} else {
return key, data, nil
}
}