Allow peering endpoints to bypass verify_incoming.

This commit is contained in:
Derek Menteer 2022-10-28 15:34:41 -05:00 committed by Derek Menteer
parent 065e538de3
commit 58f15db4c4
10 changed files with 390 additions and 30 deletions

View File

@ -43,6 +43,7 @@ import (
"github.com/hashicorp/consul/agent/dns"
external "github.com/hashicorp/consul/agent/grpc-external"
grpcDNS "github.com/hashicorp/consul/agent/grpc-external/services/dns"
middleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/hcp/scada"
libscada "github.com/hashicorp/consul/agent/hcp/scada"
"github.com/hashicorp/consul/agent/local"
@ -563,6 +564,7 @@ func (a *Agent) Start(ctx context.Context) error {
a.externalGRPCServer = external.NewServer(
a.logger.Named("grpc.external"),
metrics.Default(),
a.tlsConfigurator,
)
if err := a.startLicenseManager(ctx); err != nil {
@ -855,7 +857,7 @@ func (a *Agent) listenAndServeGRPC() error {
// Attempt to spawn listeners
var listeners []net.Listener
start := func(port_name string, addrs []net.Addr, tlsConf *tls.Config) error {
start := func(port_name string, addrs []net.Addr, protocol middleware.Protocol) error {
if len(addrs) < 1 {
return nil
}
@ -865,10 +867,7 @@ func (a *Agent) listenAndServeGRPC() error {
return err
}
for i := range ln {
// Wrap with TLS, if provided.
if tlsConf != nil {
ln[i] = tls.NewListener(ln[i], tlsConf)
}
ln[i] = middleware.LabelledListener{Listener: ln[i], Protocol: protocol}
listeners = append(listeners, ln[i])
}
@ -892,19 +891,19 @@ func (a *Agent) listenAndServeGRPC() error {
// TODO: Simplify this block to only spawn plain-text after 1.14 when deprecated TLS support is removed.
if a.config.GRPCPort > 0 {
// Only allow the grpc port to spawn TLS connections if the other grpc_tls port is NOT defined.
var tlsConf *tls.Config = nil
protocol := middleware.ProtocolPlaintext
if a.config.GRPCTLSPort <= 0 && a.tlsConfigurator.GRPCServerUseTLS() {
a.logger.Warn("deprecated gRPC TLS configuration detected. Consider using `ports.grpc_tls` instead")
tlsConf = a.tlsConfigurator.IncomingGRPCConfig()
protocol = middleware.ProtocolTLS
}
if err := start("grpc", a.config.GRPCAddrs, tlsConf); err != nil {
if err := start("grpc", a.config.GRPCAddrs, protocol); err != nil {
closeListeners(listeners)
return err
}
}
// Only allow grpc_tls to spawn with a TLS listener.
if a.config.GRPCTLSPort > 0 {
if err := start("grpc_tls", a.config.GRPCTLSAddrs, a.tlsConfigurator.IncomingGRPCConfig()); err != nil {
if err := start("grpc_tls", a.config.GRPCTLSAddrs, middleware.ProtocolTLS); err != nil {
closeListeners(listeners)
return err
}

View File

@ -2,7 +2,6 @@ package consul
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
@ -31,6 +30,7 @@ import (
"github.com/hashicorp/consul/agent/connect"
external "github.com/hashicorp/consul/agent/grpc-external"
grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/rpc/middleware"
"github.com/hashicorp/consul/agent/structs"
@ -258,7 +258,9 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort))
require.NoError(t, err)
protocol := grpcmiddleware.ProtocolPlaintext
if grpcPort == srv.config.GRPCTLSPort || deps.TLSConfigurator.GRPCServerUseTLS() {
protocol = grpcmiddleware.ProtocolTLS
// Set the internally managed server certificate. The cert manager is hooked to the Agent, so we need to bypass that here.
if srv.config.PeeringEnabled && srv.config.ConnectEnabled {
key, _ := srv.config.CAConfig.Config["PrivateKey"].(string)
@ -273,9 +275,8 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S
}
}
// Wrap the listener with TLS.
ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig())
}
ln = grpcmiddleware.LabelledListener{Listener: ln, Protocol: protocol}
go func() {
_ = srv.externalGRPCServer.Serve(ln)
@ -329,7 +330,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) {
oldNotify()
}
}
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil)
grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator)
srv, err := NewServer(c, deps, grpcServer)
if err != nil {
return nil, err

View File

@ -7,9 +7,11 @@ import (
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/tlsutil"
)
var (
@ -21,25 +23,34 @@ var (
// NewServer constructs a gRPC server for the external gRPC port, to which
// handlers can be registered.
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc.Server {
func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server {
if metricsObj == nil {
metricsObj = metrics.Default()
}
recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger)
unaryInterceptors := []grpc.UnaryServerInterceptor{
// Add middlware interceptors to recover in case of panics.
recovery.UnaryServerInterceptor(recoveryOpts...),
}
streamInterceptors := []grpc.StreamServerInterceptor{
// Add middlware interceptors to recover in case of panics.
recovery.StreamServerInterceptor(recoveryOpts...),
agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept,
}
if tls != nil {
// Attach TLS middleware if TLS is provided.
authInterceptor := agentmiddleware.AuthInterceptor{TLS: tls, Logger: logger}
unaryInterceptors = append(unaryInterceptors, authInterceptor.InterceptUnary)
streamInterceptors = append(streamInterceptors, authInterceptor.InterceptStream)
}
opts := []grpc.ServerOption{
grpc.MaxConcurrentStreams(2048),
grpc.MaxRecvMsgSize(50 * 1024 * 1024),
grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)),
middleware.WithUnaryServerChain(
// Add middlware interceptors to recover in case of panics.
recovery.UnaryServerInterceptor(recoveryOpts...),
),
middleware.WithStreamServerChain(
// Add middlware interceptors to recover in case of panics.
recovery.StreamServerInterceptor(recoveryOpts...),
agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept,
),
middleware.WithUnaryServerChain(unaryInterceptors...),
middleware.WithStreamServerChain(streamInterceptors...),
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
// This must be less than the keealive.ClientParameters Time setting, otherwise
// the server will disconnect the client for sending too many keepalive pings.
@ -47,5 +58,13 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc
MinTime: 15 * time.Second,
}),
}
if tls != nil {
// Attach TLS credentials, if provided.
tlsCreds := agentmiddleware.NewOptionalTransportCredentials(
credentials.NewTLS(tls.IncomingGRPCConfig()),
logger)
opts = append(opts, grpc.Creds(tlsCreds))
}
return grpc.NewServer(opts...)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-hclog"
grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/proto/prototest"
@ -22,12 +23,13 @@ import (
func TestServer_EmitsStats(t *testing.T) {
sink, metricsObj := testutil.NewFakeSink(t)
srv := NewServer(hclog.Default(), metricsObj)
srv := NewServer(hclog.Default(), metricsObj, nil)
testservice.RegisterSimpleServer(srv, &testservice.Simple{})
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
lis = grpcmiddleware.LabelledListener{Listener: lis, Protocol: grpcmiddleware.ProtocolPlaintext}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

View File

@ -0,0 +1,85 @@
package middleware
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/consul/tlsutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
const AllowedPeerEndpointPrefix = "/hashicorp.consul.internal.peerstream.PeerStreamService/"
// AuthInterceptor provides gRPC interceptors for restricting endpoint access based
// on SNI. If the connection is plaintext, this filter will not activate, and the
// connection will be allowed to proceed.
type AuthInterceptor struct {
TLS *tlsutil.Configurator
Logger Logger
}
// InterceptUnary prevents non-streaming gRPC calls from calling certain endpoints,
// based on the SNI information.
func (a *AuthInterceptor) InterceptUnary(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("unable to fetch peer info from grpc context")
}
err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod)
if err != nil {
return nil, err
}
return handler(ctx, req)
}
// InterceptUnary prevents streaming gRPC calls from calling certain endpoints,
// based on the SNI information.
func (a *AuthInterceptor) InterceptStream(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
p, ok := peer.FromContext(ss.Context())
if !ok {
return fmt.Errorf("unable to fetch peer info from grpc context")
}
err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod)
if err != nil {
return err
}
return handler(srv, ss)
}
// restrictPeeringEndpoints will return an error if a peering TLS connection attempts to call
// a non-peering endpoint. This is necessary, because the peer streaming workflow does not
// present a mutual TLS certificate, and is allowed to bypass the `tls.grpc.verify_incoming`
// check as a special case. See the `tlsutil.Configurator` for this bypass.
func restrictPeeringEndpoints(authInfo credentials.AuthInfo, peeringSNI string, endpoint string) error {
// This indicates a plaintext connection.
if authInfo == nil {
return nil
}
// Otherwise attempt to check the AuthInfo for TLS credentials.
tlsAuth, ok := authInfo.(credentials.TLSInfo)
if !ok {
return status.Error(codes.Unauthenticated, "invalid transport credentials")
}
if tlsAuth.State.ServerName == peeringSNI {
// Prevent any calls, except those in the PeerStreamService
if !strings.HasPrefix(endpoint, AllowedPeerEndpointPrefix) {
return status.Error(codes.PermissionDenied, "invalid permissions to the specified endpoint")
}
}
return nil
}

View File

@ -0,0 +1,79 @@
package middleware
import (
"crypto/tls"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials"
)
type invalidAuthInfo struct{}
func (i invalidAuthInfo) AuthType() string {
return "invalid."
}
func TestGRPCMiddleware_restrictPeeringEndpoints(t *testing.T) {
tests := []struct {
name string
authInfo credentials.AuthInfo
peeringSNI string
endpoint string
expectErr string
}{
{
name: "plaintext_always_allowed",
authInfo: nil,
peeringSNI: "expected-server-sni",
endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint",
},
{
name: "deny_invalid_credentials",
authInfo: invalidAuthInfo{},
expectErr: "invalid transport credentials",
},
{
name: "peering_sni_with_invalid_endpoint",
authInfo: credentials.TLSInfo{
State: tls.ConnectionState{
ServerName: "peering-sni",
},
},
peeringSNI: "peering-sni",
endpoint: "/some-invalid-endpoint",
expectErr: "invalid permissions to the specified endpoint",
},
{
name: "peering_sni_with_valid_endpoint",
authInfo: credentials.TLSInfo{
State: tls.ConnectionState{
ServerName: "peering-sni",
},
},
peeringSNI: "peering-sni",
endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint",
},
{
name: "non_peering_sni_always_allowed",
authInfo: credentials.TLSInfo{
State: tls.ConnectionState{
ServerName: "non-peering-sni",
},
},
peeringSNI: "peering-sni",
endpoint: "/some-non-peering-endpoint",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := restrictPeeringEndpoints(tc.authInfo, tc.peeringSNI, tc.endpoint)
if tc.expectErr == "" {
require.NoError(t, err)
} else {
require.Contains(t, err.Error(), tc.expectErr)
}
})
}
}

View File

@ -0,0 +1,78 @@
package middleware
import (
"fmt"
"net"
"google.golang.org/grpc/credentials"
)
var _ net.Listener = (*LabelledListener)(nil)
var _ net.Conn = (*LabelledConn)(nil)
type Protocol int
var (
ProtocolPlaintext Protocol = 0
ProtocolTLS Protocol = 1
)
// LabelledListener wraps a listener and attaches pre-specified
// fields to each spawned connection.
type LabelledListener struct {
net.Listener
Protocol Protocol
}
func (l LabelledListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if conn != nil {
conn = LabelledConn{conn, l.Protocol}
}
return conn, err
}
// LabelledConn wraps a connection and provides extra metadata fields.
type LabelledConn struct {
net.Conn
protocol Protocol
}
var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)
// optionalTransportCredentials provides a way to selectively perform a TLS handshake
// based on metadata extracted from the underlying connection object.
type optionalTransportCredentials struct {
credentials.TransportCredentials
logger Logger
}
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
return &optionalTransportCredentials{creds, logger}
}
// ServerHandshake will attempt to detect the underlying connection protocol (TLS or Plaintext)
// based on metadata attached to the underlying connection. If TLS is detected, then a handshake
// will be performed, and the corresponding AuthInfo will be attached to the gRPC context.
// For plaintext connections, this is effectively a no-op, since there is no TLS info to attach.
// If the underlying connection is not a LabelledConn with a valid protocol, then this method will
// panic and prevent the gRPC connection from successfully progressing further.
func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
// This should always be a LabelledConn, so no check is necessary.
nc := conn.(LabelledConn)
switch nc.protocol {
case ProtocolPlaintext:
// This originated from a plaintext listener, so do not use TLS auth.
return nc, nil, nil
case ProtocolTLS:
// This originated from a TLS listener, so it should have a full handshake performed.
c, ai, err := tc.TransportCredentials.ServerHandshake(conn)
if err == nil && ai == nil {
// This should not be possible, but ensure that it's non-nil for safety.
return nil, nil, fmt.Errorf("missing auth info after handshake")
}
return c, ai, err
default:
return nil, nil, fmt.Errorf("invalid protocol for grpc connection")
}
}

View File

@ -0,0 +1,84 @@
package middleware
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials"
)
type fakeTransportCredentials struct {
credentials.TransportCredentials
callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error)
}
func (f fakeTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return f.callback(conn)
}
func TestGRPCMiddleware_optionalTransportCredentials_ServerHandshake(t *testing.T) {
plainConn := LabelledConn{protocol: ProtocolPlaintext}
tlsConn := LabelledConn{protocol: ProtocolTLS}
tests := []struct {
name string
conn net.Conn
callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error)
expectConn net.Conn
expectAuthInfo credentials.AuthInfo
expectErr string
}{
{
name: "plaintext_noop",
conn: plainConn,
expectConn: plainConn,
expectAuthInfo: nil,
},
{
name: "tls_with_missing_auth",
conn: tlsConn,
callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, nil, nil
},
expectConn: nil,
expectAuthInfo: nil,
expectErr: "missing auth info after handshake",
},
{
name: "tls_success",
conn: tlsConn,
callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return conn, credentials.TLSInfo{}, nil
},
expectConn: tlsConn,
expectAuthInfo: credentials.TLSInfo{},
expectErr: "",
},
{
name: "invalid_protocol",
conn: LabelledConn{protocol: -1},
expectConn: nil,
expectAuthInfo: nil,
expectErr: "invalid protocol for grpc connection",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
creds := optionalTransportCredentials{
TransportCredentials: fakeTransportCredentials{
callback: tc.callback,
},
}
conn, authInfo, err := creds.ServerHandshake(tc.conn)
require.Equal(t, tc.expectConn, conn)
require.Equal(t, tc.expectAuthInfo, authInfo)
if tc.expectErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectErr)
}
})
}
}

View File

@ -2,7 +2,6 @@ package peering_test
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
@ -29,6 +28,7 @@ import (
"github.com/hashicorp/consul/agent/grpc-external/limiter"
grpc "github.com/hashicorp/consul/agent/grpc-internal"
"github.com/hashicorp/consul/agent/grpc-internal/resolver"
agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/agent/router"
"github.com/hashicorp/consul/agent/rpc/middleware"
@ -1546,7 +1546,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta()
deps := newDefaultDeps(t, conf)
externalGRPCServer := external.NewServer(deps.Logger, nil)
externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator)
server, err := consul.NewServer(conf, deps, externalGRPCServer)
require.NoError(t, err)
@ -1563,8 +1563,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer {
ln, err := net.Listen("tcp", grpcAddr)
require.NoError(t, err)
ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig())
ln = agentmiddleware.LabelledListener{Listener: ln, Protocol: agentmiddleware.ProtocolTLS}
go func() {
_ = externalGRPCServer.Serve(ln)

View File

@ -770,8 +770,16 @@ func (c *Configurator) IncomingGRPCConfig() *tls.Config {
c.base.GRPC,
c.base.GRPC.VerifyIncoming,
)
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
return c.IncomingGRPCConfig(), nil
config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
conf := c.IncomingGRPCConfig()
// Do not enforce mutualTLS for peering SNI entries. This is necessary, because
// there is no way to specify an mTLS cert when establishing a peering connection.
// This bypass is only safe because the `grpc-middleware.AuthInterceptor` explicitly
// restricts the list of endpoints that can be called when peering SNI is present.
if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName {
conf.ClientAuth = tls.NoClientCert
}
return conf, nil
}
config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName {
@ -950,6 +958,12 @@ func (c *Configurator) AutoEncryptCert() *x509.Certificate {
return cert
}
func (c *Configurator) PeeringServerName() string {
c.lock.RLock()
defer c.lock.RUnlock()
return c.autoTLS.peeringServerName
}
func (c *Configurator) log(name string) {
if c.logger != nil && c.logger.IsTrace() {
c.logger.Trace(name, "version", atomic.LoadUint64(&c.version))