diff --git a/agent/agent.go b/agent/agent.go index 6d386cd9f..2489ee3e6 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -43,6 +43,7 @@ import ( "github.com/hashicorp/consul/agent/dns" external "github.com/hashicorp/consul/agent/grpc-external" grpcDNS "github.com/hashicorp/consul/agent/grpc-external/services/dns" + middleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/hcp/scada" libscada "github.com/hashicorp/consul/agent/hcp/scada" "github.com/hashicorp/consul/agent/local" @@ -563,6 +564,7 @@ func (a *Agent) Start(ctx context.Context) error { a.externalGRPCServer = external.NewServer( a.logger.Named("grpc.external"), metrics.Default(), + a.tlsConfigurator, ) if err := a.startLicenseManager(ctx); err != nil { @@ -855,7 +857,7 @@ func (a *Agent) listenAndServeGRPC() error { // Attempt to spawn listeners var listeners []net.Listener - start := func(port_name string, addrs []net.Addr, tlsConf *tls.Config) error { + start := func(port_name string, addrs []net.Addr, protocol middleware.Protocol) error { if len(addrs) < 1 { return nil } @@ -865,10 +867,7 @@ func (a *Agent) listenAndServeGRPC() error { return err } for i := range ln { - // Wrap with TLS, if provided. - if tlsConf != nil { - ln[i] = tls.NewListener(ln[i], tlsConf) - } + ln[i] = middleware.LabelledListener{Listener: ln[i], Protocol: protocol} listeners = append(listeners, ln[i]) } @@ -892,19 +891,19 @@ func (a *Agent) listenAndServeGRPC() error { // TODO: Simplify this block to only spawn plain-text after 1.14 when deprecated TLS support is removed. if a.config.GRPCPort > 0 { // Only allow the grpc port to spawn TLS connections if the other grpc_tls port is NOT defined. - var tlsConf *tls.Config = nil + protocol := middleware.ProtocolPlaintext if a.config.GRPCTLSPort <= 0 && a.tlsConfigurator.GRPCServerUseTLS() { a.logger.Warn("deprecated gRPC TLS configuration detected. Consider using `ports.grpc_tls` instead") - tlsConf = a.tlsConfigurator.IncomingGRPCConfig() + protocol = middleware.ProtocolTLS } - if err := start("grpc", a.config.GRPCAddrs, tlsConf); err != nil { + if err := start("grpc", a.config.GRPCAddrs, protocol); err != nil { closeListeners(listeners) return err } } // Only allow grpc_tls to spawn with a TLS listener. if a.config.GRPCTLSPort > 0 { - if err := start("grpc_tls", a.config.GRPCTLSAddrs, a.tlsConfigurator.IncomingGRPCConfig()); err != nil { + if err := start("grpc_tls", a.config.GRPCTLSAddrs, middleware.ProtocolTLS); err != nil { closeListeners(listeners) return err } diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index faf133070..1d328b8de 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -2,7 +2,6 @@ package consul import ( "context" - "crypto/tls" "crypto/x509" "fmt" "net" @@ -31,6 +30,7 @@ import ( "github.com/hashicorp/consul/agent/connect" external "github.com/hashicorp/consul/agent/grpc-external" + grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/rpc/middleware" "github.com/hashicorp/consul/agent/structs" @@ -258,7 +258,9 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort)) require.NoError(t, err) + protocol := grpcmiddleware.ProtocolPlaintext if grpcPort == srv.config.GRPCTLSPort || deps.TLSConfigurator.GRPCServerUseTLS() { + protocol = grpcmiddleware.ProtocolTLS // Set the internally managed server certificate. The cert manager is hooked to the Agent, so we need to bypass that here. if srv.config.PeeringEnabled && srv.config.ConnectEnabled { key, _ := srv.config.CAConfig.Config["PrivateKey"].(string) @@ -273,9 +275,8 @@ func testServerWithConfig(t *testing.T, configOpts ...func(*Config)) (string, *S } } - // Wrap the listener with TLS. - ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig()) } + ln = grpcmiddleware.LabelledListener{Listener: ln, Protocol: protocol} go func() { _ = srv.externalGRPCServer.Serve(ln) @@ -329,7 +330,7 @@ func newServerWithDeps(t *testing.T, c *Config, deps Deps) (*Server, error) { oldNotify() } } - grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil) + grpcServer := external.NewServer(deps.Logger.Named("grpc.external"), nil, deps.TLSConfigurator) srv, err := NewServer(c, deps, grpcServer) if err != nil { return nil, err diff --git a/agent/grpc-external/server.go b/agent/grpc-external/server.go index 59ca0dde2..dd0186d48 100644 --- a/agent/grpc-external/server.go +++ b/agent/grpc-external/server.go @@ -7,9 +7,11 @@ import ( middleware "github.com/grpc-ecosystem/go-grpc-middleware" recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" + "github.com/hashicorp/consul/tlsutil" ) var ( @@ -21,25 +23,34 @@ var ( // NewServer constructs a gRPC server for the external gRPC port, to which // handlers can be registered. -func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc.Server { +func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics, tls *tlsutil.Configurator) *grpc.Server { if metricsObj == nil { metricsObj = metrics.Default() } recoveryOpts := agentmiddleware.PanicHandlerMiddlewareOpts(logger) + unaryInterceptors := []grpc.UnaryServerInterceptor{ + // Add middlware interceptors to recover in case of panics. + recovery.UnaryServerInterceptor(recoveryOpts...), + } + streamInterceptors := []grpc.StreamServerInterceptor{ + // Add middlware interceptors to recover in case of panics. + recovery.StreamServerInterceptor(recoveryOpts...), + agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept, + } + + if tls != nil { + // Attach TLS middleware if TLS is provided. + authInterceptor := agentmiddleware.AuthInterceptor{TLS: tls, Logger: logger} + unaryInterceptors = append(unaryInterceptors, authInterceptor.InterceptUnary) + streamInterceptors = append(streamInterceptors, authInterceptor.InterceptStream) + } opts := []grpc.ServerOption{ grpc.MaxConcurrentStreams(2048), grpc.MaxRecvMsgSize(50 * 1024 * 1024), grpc.StatsHandler(agentmiddleware.NewStatsHandler(metricsObj, metricsLabels)), - middleware.WithUnaryServerChain( - // Add middlware interceptors to recover in case of panics. - recovery.UnaryServerInterceptor(recoveryOpts...), - ), - middleware.WithStreamServerChain( - // Add middlware interceptors to recover in case of panics. - recovery.StreamServerInterceptor(recoveryOpts...), - agentmiddleware.NewActiveStreamCounter(metricsObj, metricsLabels).Intercept, - ), + middleware.WithUnaryServerChain(unaryInterceptors...), + middleware.WithStreamServerChain(streamInterceptors...), grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ // This must be less than the keealive.ClientParameters Time setting, otherwise // the server will disconnect the client for sending too many keepalive pings. @@ -47,5 +58,13 @@ func NewServer(logger agentmiddleware.Logger, metricsObj *metrics.Metrics) *grpc MinTime: 15 * time.Second, }), } + + if tls != nil { + // Attach TLS credentials, if provided. + tlsCreds := agentmiddleware.NewOptionalTransportCredentials( + credentials.NewTLS(tls.IncomingGRPCConfig()), + logger) + opts = append(opts, grpc.Creds(tlsCreds)) + } return grpc.NewServer(opts...) } diff --git a/agent/grpc-external/stats_test.go b/agent/grpc-external/stats_test.go index 8f5dffb67..deba9ee64 100644 --- a/agent/grpc-external/stats_test.go +++ b/agent/grpc-external/stats_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-hclog" + grpcmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/grpc-middleware/testutil" "github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice" "github.com/hashicorp/consul/proto/prototest" @@ -22,12 +23,13 @@ import ( func TestServer_EmitsStats(t *testing.T) { sink, metricsObj := testutil.NewFakeSink(t) - srv := NewServer(hclog.Default(), metricsObj) + srv := NewServer(hclog.Default(), metricsObj, nil) testservice.RegisterSimpleServer(srv, &testservice.Simple{}) lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + lis = grpcmiddleware.LabelledListener{Listener: lis, Protocol: grpcmiddleware.ProtocolPlaintext} ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) diff --git a/agent/grpc-middleware/auth_interceptor.go b/agent/grpc-middleware/auth_interceptor.go new file mode 100644 index 000000000..048b0b272 --- /dev/null +++ b/agent/grpc-middleware/auth_interceptor.go @@ -0,0 +1,85 @@ +package middleware + +import ( + "context" + "fmt" + "strings" + + "github.com/hashicorp/consul/tlsutil" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +const AllowedPeerEndpointPrefix = "/hashicorp.consul.internal.peerstream.PeerStreamService/" + +// AuthInterceptor provides gRPC interceptors for restricting endpoint access based +// on SNI. If the connection is plaintext, this filter will not activate, and the +// connection will be allowed to proceed. +type AuthInterceptor struct { + TLS *tlsutil.Configurator + Logger Logger +} + +// InterceptUnary prevents non-streaming gRPC calls from calling certain endpoints, +// based on the SNI information. +func (a *AuthInterceptor) InterceptUnary( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + p, ok := peer.FromContext(ctx) + if !ok { + return nil, fmt.Errorf("unable to fetch peer info from grpc context") + } + err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod) + if err != nil { + return nil, err + } + return handler(ctx, req) +} + +// InterceptUnary prevents streaming gRPC calls from calling certain endpoints, +// based on the SNI information. +func (a *AuthInterceptor) InterceptStream( + srv interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + p, ok := peer.FromContext(ss.Context()) + if !ok { + return fmt.Errorf("unable to fetch peer info from grpc context") + } + err := restrictPeeringEndpoints(p.AuthInfo, a.TLS.PeeringServerName(), info.FullMethod) + if err != nil { + return err + } + return handler(srv, ss) +} + +// restrictPeeringEndpoints will return an error if a peering TLS connection attempts to call +// a non-peering endpoint. This is necessary, because the peer streaming workflow does not +// present a mutual TLS certificate, and is allowed to bypass the `tls.grpc.verify_incoming` +// check as a special case. See the `tlsutil.Configurator` for this bypass. +func restrictPeeringEndpoints(authInfo credentials.AuthInfo, peeringSNI string, endpoint string) error { + // This indicates a plaintext connection. + if authInfo == nil { + return nil + } + // Otherwise attempt to check the AuthInfo for TLS credentials. + tlsAuth, ok := authInfo.(credentials.TLSInfo) + if !ok { + return status.Error(codes.Unauthenticated, "invalid transport credentials") + } + if tlsAuth.State.ServerName == peeringSNI { + // Prevent any calls, except those in the PeerStreamService + if !strings.HasPrefix(endpoint, AllowedPeerEndpointPrefix) { + return status.Error(codes.PermissionDenied, "invalid permissions to the specified endpoint") + } + } + return nil +} diff --git a/agent/grpc-middleware/auth_interceptor_test.go b/agent/grpc-middleware/auth_interceptor_test.go new file mode 100644 index 000000000..f0e8c9e46 --- /dev/null +++ b/agent/grpc-middleware/auth_interceptor_test.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials" +) + +type invalidAuthInfo struct{} + +func (i invalidAuthInfo) AuthType() string { + return "invalid." +} + +func TestGRPCMiddleware_restrictPeeringEndpoints(t *testing.T) { + + tests := []struct { + name string + authInfo credentials.AuthInfo + peeringSNI string + endpoint string + expectErr string + }{ + { + name: "plaintext_always_allowed", + authInfo: nil, + peeringSNI: "expected-server-sni", + endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint", + }, + { + name: "deny_invalid_credentials", + authInfo: invalidAuthInfo{}, + expectErr: "invalid transport credentials", + }, + { + name: "peering_sni_with_invalid_endpoint", + authInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + ServerName: "peering-sni", + }, + }, + peeringSNI: "peering-sni", + endpoint: "/some-invalid-endpoint", + expectErr: "invalid permissions to the specified endpoint", + }, + { + name: "peering_sni_with_valid_endpoint", + authInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + ServerName: "peering-sni", + }, + }, + peeringSNI: "peering-sni", + endpoint: "/hashicorp.consul.internal.peerstream.PeerStreamService/SomeEndpoint", + }, + { + name: "non_peering_sni_always_allowed", + authInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + ServerName: "non-peering-sni", + }, + }, + peeringSNI: "peering-sni", + endpoint: "/some-non-peering-endpoint", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := restrictPeeringEndpoints(tc.authInfo, tc.peeringSNI, tc.endpoint) + if tc.expectErr == "" { + require.NoError(t, err) + } else { + require.Contains(t, err.Error(), tc.expectErr) + } + }) + } +} diff --git a/agent/grpc-middleware/handshake.go b/agent/grpc-middleware/handshake.go new file mode 100644 index 000000000..8fe4c2739 --- /dev/null +++ b/agent/grpc-middleware/handshake.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "fmt" + "net" + + "google.golang.org/grpc/credentials" +) + +var _ net.Listener = (*LabelledListener)(nil) +var _ net.Conn = (*LabelledConn)(nil) + +type Protocol int + +var ( + ProtocolPlaintext Protocol = 0 + ProtocolTLS Protocol = 1 +) + +// LabelledListener wraps a listener and attaches pre-specified +// fields to each spawned connection. +type LabelledListener struct { + net.Listener + Protocol Protocol +} + +func (l LabelledListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if conn != nil { + conn = LabelledConn{conn, l.Protocol} + } + return conn, err +} + +// LabelledConn wraps a connection and provides extra metadata fields. +type LabelledConn struct { + net.Conn + protocol Protocol +} + +var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil) + +// optionalTransportCredentials provides a way to selectively perform a TLS handshake +// based on metadata extracted from the underlying connection object. +type optionalTransportCredentials struct { + credentials.TransportCredentials + logger Logger +} + +func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials { + return &optionalTransportCredentials{creds, logger} +} + +// ServerHandshake will attempt to detect the underlying connection protocol (TLS or Plaintext) +// based on metadata attached to the underlying connection. If TLS is detected, then a handshake +// will be performed, and the corresponding AuthInfo will be attached to the gRPC context. +// For plaintext connections, this is effectively a no-op, since there is no TLS info to attach. +// If the underlying connection is not a LabelledConn with a valid protocol, then this method will +// panic and prevent the gRPC connection from successfully progressing further. +func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + // This should always be a LabelledConn, so no check is necessary. + nc := conn.(LabelledConn) + switch nc.protocol { + case ProtocolPlaintext: + // This originated from a plaintext listener, so do not use TLS auth. + return nc, nil, nil + case ProtocolTLS: + // This originated from a TLS listener, so it should have a full handshake performed. + c, ai, err := tc.TransportCredentials.ServerHandshake(conn) + if err == nil && ai == nil { + // This should not be possible, but ensure that it's non-nil for safety. + return nil, nil, fmt.Errorf("missing auth info after handshake") + } + return c, ai, err + default: + return nil, nil, fmt.Errorf("invalid protocol for grpc connection") + } +} diff --git a/agent/grpc-middleware/handshake_test.go b/agent/grpc-middleware/handshake_test.go new file mode 100644 index 000000000..78f4f4f87 --- /dev/null +++ b/agent/grpc-middleware/handshake_test.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/credentials" +) + +type fakeTransportCredentials struct { + credentials.TransportCredentials + callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) +} + +func (f fakeTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return f.callback(conn) +} + +func TestGRPCMiddleware_optionalTransportCredentials_ServerHandshake(t *testing.T) { + + plainConn := LabelledConn{protocol: ProtocolPlaintext} + tlsConn := LabelledConn{protocol: ProtocolTLS} + tests := []struct { + name string + conn net.Conn + callback func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) + expectConn net.Conn + expectAuthInfo credentials.AuthInfo + expectErr string + }{ + { + name: "plaintext_noop", + conn: plainConn, + expectConn: plainConn, + expectAuthInfo: nil, + }, + { + name: "tls_with_missing_auth", + conn: tlsConn, + callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return conn, nil, nil + }, + expectConn: nil, + expectAuthInfo: nil, + expectErr: "missing auth info after handshake", + }, + { + name: "tls_success", + conn: tlsConn, + callback: func(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return conn, credentials.TLSInfo{}, nil + }, + expectConn: tlsConn, + expectAuthInfo: credentials.TLSInfo{}, + expectErr: "", + }, + { + name: "invalid_protocol", + conn: LabelledConn{protocol: -1}, + expectConn: nil, + expectAuthInfo: nil, + expectErr: "invalid protocol for grpc connection", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + creds := optionalTransportCredentials{ + TransportCredentials: fakeTransportCredentials{ + callback: tc.callback, + }, + } + conn, authInfo, err := creds.ServerHandshake(tc.conn) + require.Equal(t, tc.expectConn, conn) + require.Equal(t, tc.expectAuthInfo, authInfo) + if tc.expectErr == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + } + }) + } +} diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 2a576f0ef..85bcb8771 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -2,7 +2,6 @@ package peering_test import ( "context" - "crypto/tls" "encoding/base64" "encoding/json" "fmt" @@ -29,6 +28,7 @@ import ( "github.com/hashicorp/consul/agent/grpc-external/limiter" grpc "github.com/hashicorp/consul/agent/grpc-internal" "github.com/hashicorp/consul/agent/grpc-internal/resolver" + agentmiddleware "github.com/hashicorp/consul/agent/grpc-middleware" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" "github.com/hashicorp/consul/agent/rpc/middleware" @@ -1546,7 +1546,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer { conf.ACLResolverSettings.EnterpriseMeta = *conf.AgentEnterpriseMeta() deps := newDefaultDeps(t, conf) - externalGRPCServer := external.NewServer(deps.Logger, nil) + externalGRPCServer := external.NewServer(deps.Logger, nil, deps.TLSConfigurator) server, err := consul.NewServer(conf, deps, externalGRPCServer) require.NoError(t, err) @@ -1563,8 +1563,7 @@ func newTestServer(t *testing.T, cb func(conf *consul.Config)) testingServer { ln, err := net.Listen("tcp", grpcAddr) require.NoError(t, err) - - ln = tls.NewListener(ln, deps.TLSConfigurator.IncomingGRPCConfig()) + ln = agentmiddleware.LabelledListener{Listener: ln, Protocol: agentmiddleware.ProtocolTLS} go func() { _ = externalGRPCServer.Serve(ln) diff --git a/tlsutil/config.go b/tlsutil/config.go index b6d54b30b..027db6617 100644 --- a/tlsutil/config.go +++ b/tlsutil/config.go @@ -770,8 +770,16 @@ func (c *Configurator) IncomingGRPCConfig() *tls.Config { c.base.GRPC, c.base.GRPC.VerifyIncoming, ) - config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { - return c.IncomingGRPCConfig(), nil + config.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + conf := c.IncomingGRPCConfig() + // Do not enforce mutualTLS for peering SNI entries. This is necessary, because + // there is no way to specify an mTLS cert when establishing a peering connection. + // This bypass is only safe because the `grpc-middleware.AuthInterceptor` explicitly + // restricts the list of endpoints that can be called when peering SNI is present. + if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName { + conf.ClientAuth = tls.NoClientCert + } + return conf, nil } config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { if c.autoTLS.peeringServerName != "" && info.ServerName == c.autoTLS.peeringServerName { @@ -950,6 +958,12 @@ func (c *Configurator) AutoEncryptCert() *x509.Certificate { return cert } +func (c *Configurator) PeeringServerName() string { + c.lock.RLock() + defer c.lock.RUnlock() + return c.autoTLS.peeringServerName +} + func (c *Configurator) log(name string) { if c.logger != nil && c.logger.IsTrace() { c.logger.Trace(name, "version", atomic.LoadUint64(&c.version))