Allow peering endpoints to bypass verify_incoming.
This commit is contained in:
parent
065e538de3
commit
58f15db4c4
|
@ -43,6 +43,7 @@ import (
|
||||||
"github.com/hashicorp/consul/agent/dns"
|
"github.com/hashicorp/consul/agent/dns"
|
||||||
external "github.com/hashicorp/consul/agent/grpc-external"
|
external "github.com/hashicorp/consul/agent/grpc-external"
|
||||||
grpcDNS "github.com/hashicorp/consul/agent/grpc-external/services/dns"
|
grpcDNS "github.com/hashicorp/consul/agent/grpc-external/services/dns"
|
||||||
|
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/hcp/scada"
|
"github.com/hashicorp/consul/agent/hcp/scada"
|
||||||
libscada "github.com/hashicorp/consul/agent/hcp/scada"
|
libscada "github.com/hashicorp/consul/agent/hcp/scada"
|
||||||
"github.com/hashicorp/consul/agent/local"
|
"github.com/hashicorp/consul/agent/local"
|
||||||
|
@ -563,6 +564,7 @@ func (a *Agent) Start(ctx context.Context) error {
|
||||||
a.externalGRPCServer = external.NewServer(
|
a.externalGRPCServer = external.NewServer(
|
||||||
a.logger.Named("grpc.external"),
|
a.logger.Named("grpc.external"),
|
||||||
metrics.Default(),
|
metrics.Default(),
|
||||||
|
a.tlsConfigurator,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := a.startLicenseManager(ctx); err != nil {
|
if err := a.startLicenseManager(ctx); err != nil {
|
||||||
|
@ -855,7 +857,7 @@ func (a *Agent) listenAndServeGRPC() error {
|
||||||
|
|
||||||
// Attempt to spawn listeners
|
// Attempt to spawn listeners
|
||||||
var listeners []net.Listener
|
var listeners []net.Listener
|
||||||
start := func(port_name string, addrs []net.Addr, tlsConf *tls.Config) error {
|
start := func(port_name string, addrs []net.Addr, protocol middleware.Protocol) error {
|
||||||
if len(addrs) < 1 {
|
if len(addrs) < 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -865,10 +867,7 @@ func (a *Agent) listenAndServeGRPC() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range ln {
|
for i := range ln {
|
||||||
// Wrap with TLS, if provided.
|
ln[i] = middleware.LabelledListener{Listener: ln[i], Protocol: protocol}
|
||||||
if tlsConf != nil {
|
|
||||||
ln[i] = tls.NewListener(ln[i], tlsConf)
|
|
||||||
}
|
|
||||||
listeners = append(listeners, ln[i])
|
listeners = append(listeners, ln[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -892,19 +891,19 @@ func (a *Agent) listenAndServeGRPC() error {
|
||||||
// TODO: Simplify this block to only spawn plain-text after 1.14 when deprecated TLS support is removed.
|
// TODO: Simplify this block to only spawn plain-text after 1.14 when deprecated TLS support is removed.
|
||||||
if a.config.GRPCPort > 0 {
|
if a.config.GRPCPort > 0 {
|
||||||
// Only allow the grpc port to spawn TLS connections if the other grpc_tls port is NOT defined.
|
// Only allow the grpc port to spawn TLS connections if the other grpc_tls port is NOT defined.
|
||||||
var tlsConf *tls.Config = nil
|
protocol := middleware.ProtocolPlaintext
|
||||||
if a.config.GRPCTLSPort <= 0 && a.tlsConfigurator.GRPCServerUseTLS() {
|
if a.config.GRPCTLSPort <= 0 && a.tlsConfigurator.GRPCServerUseTLS() {
|
||||||
a.logger.Warn("deprecated gRPC TLS configuration detected. Consider using `ports.grpc_tls` instead")
|
a.logger.Warn("deprecated gRPC TLS configuration detected. Consider using `ports.grpc_tls` instead")
|
||||||
tlsConf = a.tlsConfigurator.IncomingGRPCConfig()
|
protocol = middleware.ProtocolTLS
|
||||||
}
|
}
|
||||||
if err := start("grpc", a.config.GRPCAddrs, tlsConf); err != nil {
|
if err := start("grpc", a.config.GRPCAddrs, protocol); err != nil {
|
||||||
closeListeners(listeners)
|
closeListeners(listeners)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Only allow grpc_tls to spawn with a TLS listener.
|
// Only allow grpc_tls to spawn with a TLS listener.
|
||||||
if a.config.GRPCTLSPort > 0 {
|
if a.config.GRPCTLSPort > 0 {
|
||||||
if err := start("grpc_tls", a.config.GRPCTLSAddrs, a.tlsConfigurator.IncomingGRPCConfig()); err != nil {
|
if err := start("grpc_tls", a.config.GRPCTLSAddrs, middleware.ProtocolTLS); err != nil {
|
||||||
closeListeners(listeners)
|
closeListeners(listeners)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package consul
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -31,6 +30,7 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/consul/agent/connect"
|
"github.com/hashicorp/consul/agent/connect"
|
||||||
external "github.com/hashicorp/consul/agent/grpc-external"
|
external "github.com/hashicorp/consul/agent/grpc-external"
|
||||||
|
grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/metadata"
|
"github.com/hashicorp/consul/agent/metadata"
|
||||||
"github.com/hashicorp/consul/agent/rpc/middleware"
|
"github.com/hashicorp/consul/agent/rpc/middleware"
|
||||||
"github.com/hashicorp/consul/agent/structs"
|
"github.com/hashicorp/consul/agent/structs"
|
||||||
|
@ -258,7 +258,9 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S
|
||||||
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort))
|
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
protocol := grpcmiddleware.ProtocolPlaintext
|
||||||
if grpcPort == srv.config.GRPCTLSPort || deps.TLSConfigurator.GRPCServerUseTLS() {
|
if grpcPort == srv.config.GRPCTLSPort || deps.TLSConfigurator.GRPCServerUseTLS() {
|
||||||
|
protocol = grpcmiddleware.ProtocolTLS
|
||||||
// Set the internally managed server certificate. The cert manager is hooked to the Agent, so we need to bypass that here.
|
// Set the internally managed server certificate. The cert manager is hooked to the Agent, so we need to bypass that here.
|
||||||
if srv.config.PeeringEnabled && srv.config.ConnectEnabled {
|
if srv.config.PeeringEnabled && srv.config.ConnectEnabled {
|
||||||
key, _ := srv.config.CAConfig.Config["PrivateKey"].(string)
|
key, _ := srv.config.CAConfig.Config["PrivateKey"].(string)
|
||||||
|
@ -273,9 +275,8 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the listener with TLS.
|
|
||||||
ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig())
|
|
||||||
}
|
}
|
||||||
|
ln = grpcmiddleware.LabelledListener{Listener: ln, Protocol: protocol}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_ = srv.externalGRPCServer.Serve(ln)
|
_ = srv.externalGRPCServer.Serve(ln)
|
||||||
|
@ -329,7 +330,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
|
||||||
oldNotify()
|
oldNotify()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil)
|
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator)
|
||||||
srv, err := NewServer(c, deps, grpcServer)
|
srv, err := NewServer(c, deps, grpcServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -7,9 +7,11 @@ import (
|
||||||
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -21,25 +23,34 @@ var (
|
||||||
|
|
||||||
// NewServer constructs a gRPC server for the external gRPC port, to which
|
// NewServer constructs a gRPC server for the external gRPC port, to which
|
||||||
// handlers can be registered.
|
// handlers can be registered.
|
||||||
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc.Server {
|
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server {
|
||||||
if metricsObj == nil {
|
if metricsObj == nil {
|
||||||
metricsObj = metrics.Default()
|
metricsObj = metrics.Default()
|
||||||
}
|
}
|
||||||
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
|
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
|
||||||
|
|
||||||
|
unaryInterceptors := []grpc.UnaryServerInterceptor{
|
||||||
|
// Add middlware interceptors to recover in case of panics.
|
||||||
|
recovery.UnaryServerInterceptor(recoveryOpts...),
|
||||||
|
}
|
||||||
|
streamInterceptors := []grpc.StreamServerInterceptor{
|
||||||
|
// Add middlware interceptors to recover in case of panics.
|
||||||
|
recovery.StreamServerInterceptor(recoveryOpts...),
|
||||||
|
agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tls != nil {
|
||||||
|
// Attach TLS middleware if TLS is provided.
|
||||||
|
authInterceptor := agentmiddleware.AuthInterceptor{TLS: tls, Logger: logger}
|
||||||
|
unaryInterceptors = append(unaryInterceptors, authInterceptor.InterceptUnary)
|
||||||
|
streamInterceptors = append(streamInterceptors, authInterceptor.InterceptStream)
|
||||||
|
}
|
||||||
opts := []grpc.ServerOption{
|
opts := []grpc.ServerOption{
|
||||||
grpc.MaxConcurrentStreams(2048),
|
grpc.MaxConcurrentStreams(2048),
|
||||||
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
|
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
|
||||||
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
|
||||||
middleware.WithUnaryServerChain(
|
middleware.WithUnaryServerChain(unaryInterceptors...),
|
||||||
// Add middlware interceptors to recover in case of panics.
|
middleware.WithStreamServerChain(streamInterceptors...),
|
||||||
recovery.UnaryServerInterceptor(recoveryOpts...),
|
|
||||||
),
|
|
||||||
middleware.WithStreamServerChain(
|
|
||||||
// Add middlware interceptors to recover in case of panics.
|
|
||||||
recovery.StreamServerInterceptor(recoveryOpts...),
|
|
||||||
agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept,
|
|
||||||
),
|
|
||||||
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
||||||
// This must be less than the keealive.ClientParameters Time setting, otherwise
|
// This must be less than the keealive.ClientParameters Time setting, otherwise
|
||||||
// the server will disconnect the client for sending too many keepalive pings.
|
// the server will disconnect the client for sending too many keepalive pings.
|
||||||
|
@ -47,5 +58,13 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc
|
||||||
MinTime: 15 * time.Second,
|
MinTime: 15 * time.Second,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tls != nil {
|
||||||
|
// Attach TLS credentials, if provided.
|
||||||
|
tlsCreds := agentmiddleware.NewOptionalTransportCredentials(
|
||||||
|
credentials.NewTLS(tls.IncomingGRPCConfig()),
|
||||||
|
logger)
|
||||||
|
opts = append(opts, grpc.Creds(tlsCreds))
|
||||||
|
}
|
||||||
return grpc.NewServer(opts...)
|
return grpc.NewServer(opts...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
|
||||||
|
grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
|
||||||
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
|
||||||
"github.com/hashicorp/consul/proto/prototest"
|
"github.com/hashicorp/consul/proto/prototest"
|
||||||
|
@ -22,12 +23,13 @@ import (
|
||||||
func TestServer_EmitsStats(t *testing.T) {
|
func TestServer_EmitsStats(t *testing.T) {
|
||||||
sink, metricsObj := testutil.NewFakeSink(t)
|
sink, metricsObj := testutil.NewFakeSink(t)
|
||||||
|
|
||||||
srv := NewServer(hclog.Default(), metricsObj)
|
srv := NewServer(hclog.Default(), metricsObj, nil)
|
||||||
|
|
||||||
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
|
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
lis = grpcmiddleware.LabelledListener{Listener: lis, Protocol: grpcmiddleware.ProtocolPlaintext}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/consul/tlsutil"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const AllowedPeerEndpointPrefix = "/hashicorp.consul.internal.peerstream.PeerStreamService/"
|
||||||
|
|
||||||
|
// AuthInterceptor provides gRPC interceptors for restricting endpoint access based
|
||||||
|
// on SNI. If the connection is plaintext, this filter will not activate, and the
|
||||||
|
// connection will be allowed to proceed.
|
||||||
|
type AuthInterceptor struct {
|
||||||
|
TLS *tlsutil.Configurator
|
||||||
|
Logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterceptUnary prevents non-streaming gRPC calls from calling certain endpoints,
|
||||||
|
// based on the SNI information.
|
||||||
|
func (a *AuthInterceptor) InterceptUnary(
|
||||||
|
ctx context.Context,
|
||||||
|
req interface{},
|
||||||
|
info *grpc.UnaryServerInfo,
|
||||||
|
handler grpc.UnaryHandler,
|
||||||
|
) (interface{}, error) {
|
||||||
|
p, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unable to fetch peer info from grpc context")
|
||||||
|
}
|
||||||
|
err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InterceptUnary prevents streaming gRPC calls from calling certain endpoints,
|
||||||
|
// based on the SNI information.
|
||||||
|
func (a *AuthInterceptor) InterceptStream(
|
||||||
|
srv interface{},
|
||||||
|
ss grpc.ServerStream,
|
||||||
|
info *grpc.StreamServerInfo,
|
||||||
|
handler grpc.StreamHandler,
|
||||||
|
) error {
|
||||||
|
p, ok := peer.FromContext(ss.Context())
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unable to fetch peer info from grpc context")
|
||||||
|
}
|
||||||
|
err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// restrictPeeringEndpoints will return an error if a peering TLS connection attempts to call
|
||||||
|
// a non-peering endpoint. This is necessary, because the peer streaming workflow does not
|
||||||
|
// present a mutual TLS certificate, and is allowed to bypass the `tls.grpc.verify_incoming`
|
||||||
|
// check as a special case. See the `tlsutil.Configurator` for this bypass.
|
||||||
|
func restrictPeeringEndpoints(authInfo credentials.AuthInfo, peeringSNI string, endpoint string) error {
|
||||||
|
// This indicates a plaintext connection.
|
||||||
|
if authInfo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Otherwise attempt to check the AuthInfo for TLS credentials.
|
||||||
|
tlsAuth, ok := authInfo.(credentials.TLSInfo)
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.Unauthenticated, "invalid transport credentials")
|
||||||
|
}
|
||||||
|
if tlsAuth.State.ServerName == peeringSNI {
|
||||||
|
// Prevent any calls, except those in the PeerStreamService
|
||||||
|
if !strings.HasPrefix(endpoint, AllowedPeerEndpointPrefix) {
|
||||||
|
return status.Error(codes.PermissionDenied, "invalid permissions to the specified endpoint")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
type invalidAuthInfo struct{}
|
||||||
|
|
||||||
|
func (i invalidAuthInfo) AuthType() string {
|
||||||
|
return "invalid."
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGRPCMiddleware_restrictPeeringEndpoints(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authInfo credentials.AuthInfo
|
||||||
|
peeringSNI string
|
||||||
|
endpoint string
|
||||||
|
expectErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "plaintext_always_allowed",
|
||||||
|
authInfo: nil,
|
||||||
|
peeringSNI: "expected-server-sni",
|
||||||
|
endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deny_invalid_credentials",
|
||||||
|
authInfo: invalidAuthInfo{},
|
||||||
|
expectErr: "invalid transport credentials",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peering_sni_with_invalid_endpoint",
|
||||||
|
authInfo: credentials.TLSInfo{
|
||||||
|
State: tls.ConnectionState{
|
||||||
|
ServerName: "peering-sni",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peeringSNI: "peering-sni",
|
||||||
|
endpoint: "/some-invalid-endpoint",
|
||||||
|
expectErr: "invalid permissions to the specified endpoint",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "peering_sni_with_valid_endpoint",
|
||||||
|
authInfo: credentials.TLSInfo{
|
||||||
|
State: tls.ConnectionState{
|
||||||
|
ServerName: "peering-sni",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peeringSNI: "peering-sni",
|
||||||
|
endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_peering_sni_always_allowed",
|
||||||
|
authInfo: credentials.TLSInfo{
|
||||||
|
State: tls.ConnectionState{
|
||||||
|
ServerName: "non-peering-sni",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
peeringSNI: "peering-sni",
|
||||||
|
endpoint: "/some-non-peering-endpoint",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := restrictPeeringEndpoints(tc.authInfo, tc.peeringSNI, tc.endpoint)
|
||||||
|
if tc.expectErr == "" {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Contains(t, err.Error(), tc.expectErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ net.Listener = (*LabelledListener)(nil)
|
||||||
|
var _ net.Conn = (*LabelledConn)(nil)
|
||||||
|
|
||||||
|
type Protocol int
|
||||||
|
|
||||||
|
var (
|
||||||
|
ProtocolPlaintext Protocol = 0
|
||||||
|
ProtocolTLS Protocol = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// LabelledListener wraps a listener and attaches pre-specified
|
||||||
|
// fields to each spawned connection.
|
||||||
|
type LabelledListener struct {
|
||||||
|
net.Listener
|
||||||
|
Protocol Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LabelledListener) Accept() (net.Conn, error) {
|
||||||
|
conn, err := l.Listener.Accept()
|
||||||
|
if conn != nil {
|
||||||
|
conn = LabelledConn{conn, l.Protocol}
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// LabelledConn wraps a connection and provides extra metadata fields.
|
||||||
|
type LabelledConn struct {
|
||||||
|
net.Conn
|
||||||
|
protocol Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)
|
||||||
|
|
||||||
|
// optionalTransportCredentials provides a way to selectively perform a TLS handshake
|
||||||
|
// based on metadata extracted from the underlying connection object.
|
||||||
|
type optionalTransportCredentials struct {
|
||||||
|
credentials.TransportCredentials
|
||||||
|
logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
|
||||||
|
return &optionalTransportCredentials{creds, logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerHandshake will attempt to detect the underlying connection protocol (TLS or Plaintext)
|
||||||
|
// based on metadata attached to the underlying connection. If TLS is detected, then a handshake
|
||||||
|
// will be performed, and the corresponding AuthInfo will be attached to the gRPC context.
|
||||||
|
// For plaintext connections, this is effectively a no-op, since there is no TLS info to attach.
|
||||||
|
// If the underlying connection is not a LabelledConn with a valid protocol, then this method will
|
||||||
|
// panic and prevent the gRPC connection from successfully progressing further.
|
||||||
|
func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
// This should always be a LabelledConn, so no check is necessary.
|
||||||
|
nc := conn.(LabelledConn)
|
||||||
|
switch nc.protocol {
|
||||||
|
case ProtocolPlaintext:
|
||||||
|
// This originated from a plaintext listener, so do not use TLS auth.
|
||||||
|
return nc, nil, nil
|
||||||
|
case ProtocolTLS:
|
||||||
|
// This originated from a TLS listener, so it should have a full handshake performed.
|
||||||
|
c, ai, err := tc.TransportCredentials.ServerHandshake(conn)
|
||||||
|
if err == nil && ai == nil {
|
||||||
|
// This should not be possible, but ensure that it's non-nil for safety.
|
||||||
|
return nil, nil, fmt.Errorf("missing auth info after handshake")
|
||||||
|
}
|
||||||
|
return c, ai, err
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("invalid protocol for grpc connection")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,84 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTransportCredentials struct {
|
||||||
|
credentials.TransportCredentials
|
||||||
|
callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
return f.callback(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGRPCMiddleware_optionalTransportCredentials_ServerHandshake(t *testing.T) {
|
||||||
|
|
||||||
|
plainConn := LabelledConn{protocol: ProtocolPlaintext}
|
||||||
|
tlsConn := LabelledConn{protocol: ProtocolTLS}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
conn net.Conn
|
||||||
|
callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error)
|
||||||
|
expectConn net.Conn
|
||||||
|
expectAuthInfo credentials.AuthInfo
|
||||||
|
expectErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "plaintext_noop",
|
||||||
|
conn: plainConn,
|
||||||
|
expectConn: plainConn,
|
||||||
|
expectAuthInfo: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls_with_missing_auth",
|
||||||
|
conn: tlsConn,
|
||||||
|
callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
return conn, nil, nil
|
||||||
|
},
|
||||||
|
expectConn: nil,
|
||||||
|
expectAuthInfo: nil,
|
||||||
|
expectErr: "missing auth info after handshake",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls_success",
|
||||||
|
conn: tlsConn,
|
||||||
|
callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
return conn, credentials.TLSInfo{}, nil
|
||||||
|
},
|
||||||
|
expectConn: tlsConn,
|
||||||
|
expectAuthInfo: credentials.TLSInfo{},
|
||||||
|
expectErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_protocol",
|
||||||
|
conn: LabelledConn{protocol: -1},
|
||||||
|
expectConn: nil,
|
||||||
|
expectAuthInfo: nil,
|
||||||
|
expectErr: "invalid protocol for grpc connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
creds := optionalTransportCredentials{
|
||||||
|
TransportCredentials: fakeTransportCredentials{
|
||||||
|
callback: tc.callback,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, authInfo, err := creds.ServerHandshake(tc.conn)
|
||||||
|
require.Equal(t, tc.expectConn, conn)
|
||||||
|
require.Equal(t, tc.expectAuthInfo, authInfo)
|
||||||
|
if tc.expectErr == "" {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.expectErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,7 +2,6 @@ package peering_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -29,6 +28,7 @@ import (
|
||||||
"github.com/hashicorp/consul/agent/grpc-external/limiter"
|
"github.com/hashicorp/consul/agent/grpc-external/limiter"
|
||||||
grpc "github.com/hashicorp/consul/agent/grpc-internal"
|
grpc "github.com/hashicorp/consul/agent/grpc-internal"
|
||||||
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
|
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
|
||||||
|
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
|
||||||
"github.com/hashicorp/consul/agent/pool"
|
"github.com/hashicorp/consul/agent/pool"
|
||||||
"github.com/hashicorp/consul/agent/router"
|
"github.com/hashicorp/consul/agent/router"
|
||||||
"github.com/hashicorp/consul/agent/rpc/middleware"
|
"github.com/hashicorp/consul/agent/rpc/middleware"
|
||||||
|
@ -1546,7 +1546,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
|
||||||
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
|
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
|
||||||
|
|
||||||
deps := newDefaultDeps(t, conf)
|
deps := newDefaultDeps(t, conf)
|
||||||
externalGRPCServer := external.NewServer(deps.Logger, nil)
|
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator)
|
||||||
|
|
||||||
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
server, err := consul.NewServer(conf, deps, externalGRPCServer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -1563,8 +1563,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", grpcAddr)
|
ln, err := net.Listen("tcp", grpcAddr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
ln = agentmiddleware.LabelledListener{Listener: ln, Protocol: agentmiddleware.ProtocolTLS}
|
||||||
ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig())
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_ = externalGRPCServer.Serve(ln)
|
_ = externalGRPCServer.Serve(ln)
|
||||||
|
|
|
@ -770,8 +770,16 @@ func (c *Configurator) IncomingGRPCConfig() *tls.Config {
|
||||||
c.base.GRPC,
|
c.base.GRPC,
|
||||||
c.base.GRPC.VerifyIncoming,
|
c.base.GRPC.VerifyIncoming,
|
||||||
)
|
)
|
||||||
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
return c.IncomingGRPCConfig(), nil
|
conf := c.IncomingGRPCConfig()
|
||||||
|
// Do not enforce mutualTLS for peering SNI entries. This is necessary, because
|
||||||
|
// there is no way to specify an mTLS cert when establishing a peering connection.
|
||||||
|
// This bypass is only safe because the `grpc-middleware.AuthInterceptor` explicitly
|
||||||
|
// restricts the list of endpoints that can be called when peering SNI is present.
|
||||||
|
if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName {
|
||||||
|
conf.ClientAuth = tls.NoClientCert
|
||||||
|
}
|
||||||
|
return conf, nil
|
||||||
}
|
}
|
||||||
config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName {
|
if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName {
|
||||||
|
@ -950,6 +958,12 @@ func (c *Configurator) AutoEncryptCert() *x509.Certificate {
|
||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Configurator) PeeringServerName() string {
|
||||||
|
c.lock.RLock()
|
||||||
|
defer c.lock.RUnlock()
|
||||||
|
return c.autoTLS.peeringServerName
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Configurator) log(name string) {
|
func (c *Configurator) log(name string) {
|
||||||
if c.logger != nil && c.logger.IsTrace() {
|
if c.logger != nil && c.logger.IsTrace() {
|
||||||
c.logger.Trace(name, "version", atomic.LoadUint64(&c.version))
|
c.logger.Trace(name, "version", atomic.LoadUint64(&c.version))
|
||||||
|
|
Loading…
Reference in New Issue