82 lines
2.6 KiB
Go
82 lines
2.6 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
|
|
"google.golang.org/grpc/credentials"
|
|
)
|
|
|
|
var _ net.Listener = (*LabelledListener)(nil)
|
|
var _ net.Conn = (*LabelledConn)(nil)
|
|
|
|
type Protocol int
|
|
|
|
var (
|
|
ProtocolPlaintext Protocol = 0
|
|
ProtocolTLS Protocol = 1
|
|
)
|
|
|
|
// LabelledListener wraps a listener and attaches pre-specified
|
|
// fields to each spawned connection.
|
|
type LabelledListener struct {
|
|
net.Listener
|
|
Protocol Protocol
|
|
}
|
|
|
|
func (l LabelledListener) Accept() (net.Conn, error) {
|
|
conn, err := l.Listener.Accept()
|
|
if conn != nil {
|
|
conn = LabelledConn{conn, l.Protocol}
|
|
}
|
|
return conn, err
|
|
}
|
|
|
|
// LabelledConn wraps a connection and provides extra metadata fields.
|
|
type LabelledConn struct {
|
|
net.Conn
|
|
protocol Protocol
|
|
}
|
|
|
|
var _ credentials.TransportCredentials = (*optionalTransportCredentials)(nil)
|
|
|
|
// optionalTransportCredentials provides a way to selectively perform a TLS handshake
|
|
// based on metadata extracted from the underlying connection object.
|
|
type optionalTransportCredentials struct {
|
|
credentials.TransportCredentials
|
|
logger Logger
|
|
}
|
|
|
|
func NewOptionalTransportCredentials(creds credentials.TransportCredentials, logger Logger) credentials.TransportCredentials {
|
|
return &optionalTransportCredentials{creds, logger}
|
|
}
|
|
|
|
// ServerHandshake will attempt to detect the underlying connection protocol (TLS or Plaintext)
|
|
// based on metadata attached to the underlying connection. If TLS is detected, then a handshake
|
|
// will be performed, and the corresponding AuthInfo will be attached to the gRPC context.
|
|
// For plaintext connections, this is effectively a no-op, since there is no TLS info to attach.
|
|
// If the underlying connection is not a LabelledConn with a valid protocol, then this method will
|
|
// panic and prevent the gRPC connection from successfully progressing further.
|
|
func (tc *optionalTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
|
// This should always be a LabelledConn, so no check is necessary.
|
|
nc := conn.(LabelledConn)
|
|
switch nc.protocol {
|
|
case ProtocolPlaintext:
|
|
// This originated from a plaintext listener, so do not use TLS auth.
|
|
return nc, nil, nil
|
|
case ProtocolTLS:
|
|
// This originated from a TLS listener, so it should have a full handshake performed.
|
|
c, ai, err := tc.TransportCredentials.ServerHandshake(conn)
|
|
if err == nil && ai == nil {
|
|
// This should not be possible, but ensure that it's non-nil for safety.
|
|
return nil, nil, fmt.Errorf("missing auth info after handshake")
|
|
}
|
|
return c, ai, err
|
|
default:
|
|
return nil, nil, fmt.Errorf("invalid protocol for grpc connection")
|
|
}
|
|
}
|