diff --git a/agent/consul/config.go b/agent/consul/config.go index 431647565..c1b2451ab 100644 --- a/agent/consul/config.go +++ b/agent/consul/config.go @@ -470,6 +470,9 @@ type Config struct { // AutoEncrypt.Sign requests. AutoEncryptAllowTLS bool + // TODO: godoc, set this value from Agent + EnableGRPCServer bool + // Embedded Consul Enterprise specific configuration *EnterpriseConfig } diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index 0a520dcee..ac1096292 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -188,6 +188,9 @@ func (s *Server) handleConn(conn net.Conn, isTLS bool) { conn = tls.Server(conn, s.tlsConfigurator.IncomingInsecureRPCConfig()) s.handleInsecureConn(conn) + case pool.RPCGRPC: + s.grpcHandler.Handle(conn) + default: if !s.handleEnterpriseRPCConn(typ, conn, isTLS) { s.rpcLogger().Error("unrecognized RPC byte", @@ -254,6 +257,9 @@ func (s *Server) handleNativeTLS(conn net.Conn) { case pool.ALPN_RPCSnapshot: s.handleSnapshotConn(tlsConn) + case pool.ALPN_RPCGRPC: + s.grpcHandler.Handle(conn) + case pool.ALPN_WANGossipPacket: if err := s.handleALPN_WANGossipPacketStream(tlsConn); err != nil && err != io.EOF { s.rpcLogger().Error( diff --git a/agent/consul/server.go b/agent/consul/server.go index c1c1a6d76..2a496d962 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/consul/agent/consul/fsm" "github.com/hashicorp/consul/agent/consul/state" "github.com/hashicorp/consul/agent/consul/usagemetrics" + "github.com/hashicorp/consul/agent/grpc" "github.com/hashicorp/consul/agent/metadata" "github.com/hashicorp/consul/agent/pool" "github.com/hashicorp/consul/agent/router" @@ -239,8 +240,9 @@ type Server struct { rpcConnLimiter connlimit.Limiter // Listener is used to listen for incoming connections - Listener net.Listener - rpcServer *rpc.Server + Listener net.Listener + grpcHandler connHandler + rpcServer *rpc.Server // insecureRPCServer is a RPC server that is configure with // IncomingInsecureRPCConfig to allow clients to call AutoEncrypt.Sign @@ -314,6 +316,12 @@ type Server struct { EnterpriseServer } +type connHandler interface { + Run() error + Handle(conn net.Conn) + Shutdown() error +} + // NewServer is used to construct a new Consul server from the configuration // and extra options, potentially returning an error. func NewServer(config *Config, options ...ConsulOption) (*Server, error) { @@ -603,6 +611,8 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { } go reporter.Run(&lib.StopChannelContext{StopCh: s.shutdownCh}) + s.grpcHandler = newGRPCHandlerFromConfig(logger, config) + // Initialize Autopilot. This must happen before starting leadership monitoring // as establishing leadership could attempt to use autopilot and cause a panic. s.initAutopilot(config) @@ -612,6 +622,11 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { go s.monitorLeadership() // Start listening for RPC requests. + go func() { + if err := s.grpcHandler.Run(); err != nil { + s.logger.Error("gRPC server failed", "error", err) + } + }() go s.listen(s.Listener) // Start listeners for any segments with separate RPC listeners. @@ -625,6 +640,14 @@ func NewServer(config *Config, options ...ConsulOption) (*Server, error) { return s, nil } +func newGRPCHandlerFromConfig(logger hclog.Logger, config *Config) connHandler { + if !config.EnableGRPCServer { + return grpc.NoOpHandler{Logger: logger} + } + + return grpc.NewHandler(config.RPCAddr) +} + func (s *Server) connectCARootsMonitor(ctx context.Context) { for { ws := memdb.NewWatchSet() @@ -949,6 +972,12 @@ func (s *Server) Shutdown() error { s.Listener.Close() } + if s.grpcHandler != nil { + if err := s.grpcHandler.Shutdown(); err != nil { + s.logger.Warn("failed to stop gRPC server", "error", err) + } + } + // Close the connection pool if s.connPool != nil { s.connPool.Shutdown() diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go new file mode 100644 index 000000000..ab23537ff --- /dev/null +++ b/agent/grpc/handler.go @@ -0,0 +1,109 @@ +/* +Package grpc provides a Handler and client for agent gRPC connections. +*/ +package grpc + +import ( + "fmt" + "net" + + "google.golang.org/grpc" +) + +// NewHandler returns a gRPC server that accepts connections from Handle(conn). +func NewHandler(addr net.Addr) *Handler { + // We don't need to pass tls.Config to the server since it's multiplexed + // behind the RPC listener, which already has TLS configured. + srv := grpc.NewServer( + grpc.StatsHandler(&statsHandler{}), + grpc.StreamInterceptor((&activeStreamCounter{}).Intercept), + ) + + // TODO(streaming): add gRPC services to srv here + + return &Handler{ + srv: srv, + listener: &chanListener{addr: addr, conns: make(chan net.Conn)}, + } +} + +// Handler implements a handler for the rpc server listener, and the +// agent.Component interface for managing the lifecycle of the grpc.Server. +type Handler struct { + srv *grpc.Server + listener *chanListener +} + +// Handle the connection by sending it to a channel for the grpc.Server to receive. +func (h *Handler) Handle(conn net.Conn) { + h.listener.conns <- conn +} + +func (h *Handler) Run() error { + return h.srv.Serve(h.listener) +} + +func (h *Handler) Shutdown() error { + h.srv.Stop() + return nil +} + +// chanListener implements net.Listener for grpc.Server. +type chanListener struct { + conns chan net.Conn + addr net.Addr +} + +// Accept blocks until a connection is received from Handle, and then returns the +// connection. Accept implements part of the net.Listener interface for grpc.Server. +func (l *chanListener) Accept() (net.Conn, error) { + return <-l.conns, nil +} + +func (l *chanListener) Addr() net.Addr { + return l.addr +} + +// Close does nothing. The connections are managed by the caller. +func (l *chanListener) Close() error { + return nil +} + +// NoOpHandler implements the same methods as Handler, but performs no handling. +// It may be used in place of Handler to disable the grpc server. +type NoOpHandler struct { + Logger Logger +} + +type Logger interface { + Error(string, ...interface{}) +} + +func (h NoOpHandler) Handle(conn net.Conn) { + h.Logger.Error("gRPC conn opened but gRPC RPC is disabled, closing", + "conn", logConn(conn)) + _ = conn.Close() +} + +func (h NoOpHandler) Run() error { + return nil +} + +func (h NoOpHandler) Shutdown() error { + return nil +} + +// logConn is a local copy of github.com/hashicorp/memberlist.LogConn, to avoid +// a large dependency for a minor formatting function. +// logConn is used to keep log formatting consistent. +func logConn(conn net.Conn) string { + if conn == nil { + return "from=" + } + addr := conn.RemoteAddr() + if addr == nil { + return "from=" + } + + return fmt.Sprintf("from=%s", addr.String()) +} diff --git a/agent/grpc/internal/testservice/simple.pb.binary.go b/agent/grpc/internal/testservice/simple.pb.binary.go new file mode 100644 index 000000000..ef203aaa6 --- /dev/null +++ b/agent/grpc/internal/testservice/simple.pb.binary.go @@ -0,0 +1,28 @@ +// Code generated by protoc-gen-go-binary. DO NOT EDIT. +// source: agent/grpc/internal/testservice/simple.proto + +package testservice + +import ( + "github.com/golang/protobuf/proto" +) + +// MarshalBinary implements encoding.BinaryMarshaler +func (msg *Req) MarshalBinary() ([]byte, error) { + return proto.Marshal(msg) +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (msg *Req) UnmarshalBinary(b []byte) error { + return proto.Unmarshal(b, msg) +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (msg *Resp) MarshalBinary() ([]byte, error) { + return proto.Marshal(msg) +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (msg *Resp) UnmarshalBinary(b []byte) error { + return proto.Unmarshal(b, msg) +} diff --git a/agent/grpc/internal/testservice/simple.pb.go b/agent/grpc/internal/testservice/simple.pb.go new file mode 100644 index 000000000..ee6ebc1ec --- /dev/null +++ b/agent/grpc/internal/testservice/simple.pb.go @@ -0,0 +1,691 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: agent/grpc/internal/testservice/simple.proto + +package testservice + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Req struct { + Datacenter string `protobuf:"bytes,1,opt,name=Datacenter,proto3" json:"Datacenter,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Req) Reset() { *m = Req{} } +func (m *Req) String() string { return proto.CompactTextString(m) } +func (*Req) ProtoMessage() {} +func (*Req) Descriptor() ([]byte, []int) { + return fileDescriptor_3009a77c573f826d, []int{0} +} +func (m *Req) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Req) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Req.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Req) XXX_Merge(src proto.Message) { + xxx_messageInfo_Req.Merge(m, src) +} +func (m *Req) XXX_Size() int { + return m.Size() +} +func (m *Req) XXX_DiscardUnknown() { + xxx_messageInfo_Req.DiscardUnknown(m) +} + +var xxx_messageInfo_Req proto.InternalMessageInfo + +func (m *Req) GetDatacenter() string { + if m != nil { + return m.Datacenter + } + return "" +} + +type Resp struct { + ServerName string `protobuf:"bytes,1,opt,name=ServerName,proto3" json:"ServerName,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Resp) Reset() { *m = Resp{} } +func (m *Resp) String() string { return proto.CompactTextString(m) } +func (*Resp) ProtoMessage() {} +func (*Resp) Descriptor() ([]byte, []int) { + return fileDescriptor_3009a77c573f826d, []int{1} +} +func (m *Resp) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Resp) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Resp.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Resp) XXX_Merge(src proto.Message) { + xxx_messageInfo_Resp.Merge(m, src) +} +func (m *Resp) XXX_Size() int { + return m.Size() +} +func (m *Resp) XXX_DiscardUnknown() { + xxx_messageInfo_Resp.DiscardUnknown(m) +} + +var xxx_messageInfo_Resp proto.InternalMessageInfo + +func (m *Resp) GetServerName() string { + if m != nil { + return m.ServerName + } + return "" +} + +func init() { + proto.RegisterType((*Req)(nil), "testservice.Req") + proto.RegisterType((*Resp)(nil), "testservice.Resp") +} + +func init() { + proto.RegisterFile("agent/grpc/internal/testservice/simple.proto", fileDescriptor_3009a77c573f826d) +} + +var fileDescriptor_3009a77c573f826d = []byte{ + // 200 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x49, 0x4c, 0x4f, 0xcd, + 0x2b, 0xd1, 0x4f, 0x2f, 0x2a, 0x48, 0xd6, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, + 0x2f, 0x49, 0x2d, 0x2e, 0x29, 0x4e, 0x2d, 0x2a, 0xcb, 0x4c, 0x4e, 0xd5, 0x2f, 0xce, 0xcc, 0x2d, + 0xc8, 0x49, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x46, 0x92, 0x51, 0x52, 0xe5, 0x62, + 0x0e, 0x4a, 0x2d, 0x14, 0x92, 0xe3, 0xe2, 0x72, 0x49, 0x2c, 0x49, 0x4c, 0x4e, 0x05, 0xe9, 0x96, + 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x42, 0x12, 0x51, 0x52, 0xe3, 0x62, 0x09, 0x4a, 0x2d, 0x2e, + 0x00, 0xa9, 0x0b, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xf2, 0x4b, 0xcc, 0x4d, 0x85, 0xa9, 0x43, 0x88, + 0x18, 0xe5, 0x72, 0xb1, 0x05, 0x83, 0xed, 0x12, 0x32, 0xe2, 0xe2, 0x0c, 0xce, 0xcf, 0x4d, 0x2d, + 0xc9, 0xc8, 0xcc, 0x4b, 0x17, 0x12, 0xd0, 0x43, 0xb2, 0x53, 0x2f, 0x28, 0xb5, 0x50, 0x4a, 0x10, + 0x4d, 0xa4, 0xb8, 0x40, 0x89, 0x41, 0x48, 0x9f, 0x8b, 0xc5, 0x2d, 0x27, 0xbf, 0x9c, 0x48, 0xe5, + 0x06, 0x8c, 0x4e, 0x02, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, 0x91, 0x1c, + 0xe3, 0x8c, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0x3f, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, + 0x61, 0xd3, 0x5e, 0xba, 0x13, 0x01, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// SimpleClient is the client API for Simple service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type SimpleClient interface { + Something(ctx context.Context, in *Req, opts ...grpc.CallOption) (*Resp, error) + Flow(ctx context.Context, in *Req, opts ...grpc.CallOption) (Simple_FlowClient, error) +} + +type simpleClient struct { + cc *grpc.ClientConn +} + +func NewSimpleClient(cc *grpc.ClientConn) SimpleClient { + return &simpleClient{cc} +} + +func (c *simpleClient) Something(ctx context.Context, in *Req, opts ...grpc.CallOption) (*Resp, error) { + out := new(Resp) + err := c.cc.Invoke(ctx, "/testservice.Simple/Something", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *simpleClient) Flow(ctx context.Context, in *Req, opts ...grpc.CallOption) (Simple_FlowClient, error) { + stream, err := c.cc.NewStream(ctx, &_Simple_serviceDesc.Streams[0], "/testservice.Simple/Flow", opts...) + if err != nil { + return nil, err + } + x := &simpleFlowClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type Simple_FlowClient interface { + Recv() (*Resp, error) + grpc.ClientStream +} + +type simpleFlowClient struct { + grpc.ClientStream +} + +func (x *simpleFlowClient) Recv() (*Resp, error) { + m := new(Resp) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// SimpleServer is the server API for Simple service. +type SimpleServer interface { + Something(context.Context, *Req) (*Resp, error) + Flow(*Req, Simple_FlowServer) error +} + +// UnimplementedSimpleServer can be embedded to have forward compatible implementations. +type UnimplementedSimpleServer struct { +} + +func (*UnimplementedSimpleServer) Something(ctx context.Context, req *Req) (*Resp, error) { + return nil, status.Errorf(codes.Unimplemented, "method Something not implemented") +} +func (*UnimplementedSimpleServer) Flow(req *Req, srv Simple_FlowServer) error { + return status.Errorf(codes.Unimplemented, "method Flow not implemented") +} + +func RegisterSimpleServer(s *grpc.Server, srv SimpleServer) { + s.RegisterService(&_Simple_serviceDesc, srv) +} + +func _Simple_Something_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Req) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SimpleServer).Something(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/testservice.Simple/Something", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SimpleServer).Something(ctx, req.(*Req)) + } + return interceptor(ctx, in, info, handler) +} + +func _Simple_Flow_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Req) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(SimpleServer).Flow(m, &simpleFlowServer{stream}) +} + +type Simple_FlowServer interface { + Send(*Resp) error + grpc.ServerStream +} + +type simpleFlowServer struct { + grpc.ServerStream +} + +func (x *simpleFlowServer) Send(m *Resp) error { + return x.ServerStream.SendMsg(m) +} + +var _Simple_serviceDesc = grpc.ServiceDesc{ + ServiceName: "testservice.Simple", + HandlerType: (*SimpleServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Something", + Handler: _Simple_Something_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Flow", + Handler: _Simple_Flow_Handler, + ServerStreams: true, + }, + }, + Metadata: "agent/grpc/internal/testservice/simple.proto", +} + +func (m *Req) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Req) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Req) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if len(m.Datacenter) > 0 { + i -= len(m.Datacenter) + copy(dAtA[i:], m.Datacenter) + i = encodeVarintSimple(dAtA, i, uint64(len(m.Datacenter))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *Resp) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Resp) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Resp) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if len(m.ServerName) > 0 { + i -= len(m.ServerName) + copy(dAtA[i:], m.ServerName) + i = encodeVarintSimple(dAtA, i, uint64(len(m.ServerName))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintSimple(dAtA []byte, offset int, v uint64) int { + offset -= sovSimple(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Req) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Datacenter) + if l > 0 { + n += 1 + l + sovSimple(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *Resp) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.ServerName) + if l > 0 { + n += 1 + l + sovSimple(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func sovSimple(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozSimple(x uint64) (n int) { + return sovSimple(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Req) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowSimple + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Req: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Req: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Datacenter", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowSimple + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthSimple + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthSimple + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Datacenter = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipSimple(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthSimple + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthSimple + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Resp) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowSimple + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Resp: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Resp: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ServerName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowSimple + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthSimple + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthSimple + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ServerName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipSimple(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthSimple + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthSimple + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipSimple(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowSimple + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowSimple + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowSimple + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthSimple + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthSimple + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowSimple + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipSimple(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthSimple + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthSimple = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowSimple = fmt.Errorf("proto: integer overflow") +) diff --git a/agent/grpc/internal/testservice/simple.proto b/agent/grpc/internal/testservice/simple.proto new file mode 100644 index 000000000..bffa86def --- /dev/null +++ b/agent/grpc/internal/testservice/simple.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package testservice; + +// Simple service is used to test gRPC plumbing. +service Simple { + rpc Something(Req) returns (Resp) {} + rpc Flow(Req) returns (stream Resp) {} +} + +message Req { + string Datacenter = 1; +} + +message Resp { + string ServerName = 1; +} \ No newline at end of file diff --git a/agent/grpc/stats.go b/agent/grpc/stats.go new file mode 100644 index 000000000..cbf443878 --- /dev/null +++ b/agent/grpc/stats.go @@ -0,0 +1,81 @@ +package grpc + +import ( + "context" + "sync/atomic" + + "github.com/armon/go-metrics" + "google.golang.org/grpc" + "google.golang.org/grpc/stats" +) + +// statsHandler is a grpc/stats.StatsHandler which emits connection and +// request metrics to go-metrics. +type statsHandler struct { + activeConns uint64 // must be 8-byte aligned for atomic access +} + +// TagRPC implements grpcStats.StatsHandler +func (c *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + // No-op + return ctx +} + +// HandleRPC implements grpcStats.StatsHandler +func (c *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + label := "server" + if s.IsClient() { + label = "client" + } + switch s.(type) { + case *stats.InHeader: + metrics.IncrCounter([]string{"grpc", label, "request"}, 1) + } +} + +// TagConn implements grpcStats.StatsHandler +func (c *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + // No-op + return ctx +} + +// HandleConn implements grpcStats.StatsHandler +func (c *statsHandler) HandleConn(_ context.Context, s stats.ConnStats) { + label := "server" + if s.IsClient() { + label = "client" + } + var count uint64 + switch s.(type) { + case *stats.ConnBegin: + count = atomic.AddUint64(&c.activeConns, 1) + case *stats.ConnEnd: + // Decrement! + count = atomic.AddUint64(&c.activeConns, ^uint64(0)) + } + metrics.SetGauge([]string{"grpc", label, "active_conns"}, float32(count)) +} + +type activeStreamCounter struct { + // count of the number of open streaming RPCs on a server. It is accessed + // atomically. + count uint64 +} + +// GRPCCountingStreamInterceptor is a grpc.ServerStreamInterceptor that emits a +// a metric of the count of open streams. +func (i *activeStreamCounter) Intercept( + srv interface{}, + ss grpc.ServerStream, + _ *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + count := atomic.AddUint64(&i.count, 1) + metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count)) + defer func() { + count := atomic.AddUint64(&i.count, ^uint64(0)) + metrics.SetGauge([]string{"grpc", "server", "active_streams"}, float32(count)) + }() + + return handler(srv, ss) +} diff --git a/agent/grpc/stats_test.go b/agent/grpc/stats_test.go new file mode 100644 index 000000000..cc9910070 --- /dev/null +++ b/agent/grpc/stats_test.go @@ -0,0 +1,124 @@ +package grpc + +import ( + "context" + "net" + "testing" + "time" + + "github.com/armon/go-metrics" + "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +func TestHandler_EmitsStats(t *testing.T) { + sink := patchGlobalMetrics(t) + + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + handler := NewHandler(addr) + + testservice.RegisterSimpleServer(handler.srv, &simple{}) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer lis.Close() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return handler.srv.Serve(lis) + }) + t.Cleanup(func() { + if err := handler.Shutdown(); err != nil { + t.Logf("grpc server shutdown: %v", err) + } + if err := g.Wait(); err != nil { + t.Logf("grpc server error: %v", err) + } + }) + + conn, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure()) + require.NoError(t, err) + defer conn.Close() + + client := testservice.NewSimpleClient(conn) + fClient, err := client.Flow(ctx, &testservice.Req{Datacenter: "mine"}) + require.NoError(t, err) + + // Wait for the first event so that we know the stream is sending. + _, err = fClient.Recv() + require.NoError(t, err) + + expectedCounter := []metricCall{ + {key: []string{"testing", "grpc", "server", "request"}, val: 1}, + } + require.Equal(t, expectedCounter, sink.incrCounterCalls) + expectedGauge := []metricCall{ + {key: []string{"testing", "grpc", "server", "active_conns"}, val: 1}, + {key: []string{"testing", "grpc", "server", "active_streams"}, val: 1}, + // TODO: why is the count reset to 0 before the client receives the second message? + {key: []string{"testing", "grpc", "server", "active_streams"}, val: 0}, + } + require.Equal(t, expectedGauge, sink.gaugeCalls) +} + +type simple struct { + name string +} + +func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error { + if err := flow.Send(&testservice.Resp{ServerName: "one"}); err != nil { + return err + } + if err := flow.Send(&testservice.Resp{ServerName: "two"}); err != nil { + return err + } + return nil +} + +func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { + return &testservice.Resp{ServerName: "the-fake-service-name"}, nil +} + +func patchGlobalMetrics(t *testing.T) *fakeMetricsSink { + t.Helper() + + sink := &fakeMetricsSink{} + cfg := &metrics.Config{ + ServiceName: "testing", + TimerGranularity: time.Millisecond, // Timers are in milliseconds + ProfileInterval: time.Second, // Poll runtime every second + FilterDefault: true, + } + _, err := metrics.NewGlobal(cfg, sink) + require.NoError(t, err) + t.Cleanup(func() { + _, err = metrics.NewGlobal(cfg, &metrics.BlackholeSink{}) + require.NoError(t, err, "failed to reset global metrics") + }) + return sink +} + +type fakeMetricsSink struct { + metrics.BlackholeSink + gaugeCalls []metricCall + incrCounterCalls []metricCall +} + +func (f *fakeMetricsSink) SetGaugeWithLabels(key []string, val float32, labels []metrics.Label) { + f.gaugeCalls = append(f.gaugeCalls, metricCall{key: key, val: val, labels: labels}) +} + +func (f *fakeMetricsSink) IncrCounterWithLabels(key []string, val float32, labels []metrics.Label) { + f.incrCounterCalls = append(f.incrCounterCalls, metricCall{key: key, val: val, labels: labels}) +} + +type metricCall struct { + key []string + val float32 + labels []metrics.Label +} diff --git a/agent/pool/conn.go b/agent/pool/conn.go index 8a046fa4c..79731953b 100644 --- a/agent/pool/conn.go +++ b/agent/pool/conn.go @@ -40,23 +40,24 @@ const ( // that is supported and it might be the only one there // ever is. RPCTLSInsecure = 7 + RPCGRPC = 8 - // RPCMaxTypeValue is the maximum rpc type byte value currently used for - // the various protocols riding over our "rpc" port. + // RPCMaxTypeValue is the maximum rpc type byte value currently used for the + // various protocols riding over our "rpc" port. // - // Currently our 0-7 values are mutually exclusive with any valid first - // byte of a TLS header. The first TLS header byte will begin with a TLS - // content type and the values 0-19 are all explicitly unassigned and - // marked as requiring coordination. RFC 7983 does the marking and goes - // into some details about multiplexing connections and identifying TLS. + // Currently our 0-8 values are mutually exclusive with any valid first byte + // of a TLS header. The first TLS header byte will begin with a TLS content + // type and the values 0-19 are all explicitly unassigned and marked as + // requiring coordination. RFC 7983 does the marking and goes into some + // details about multiplexing connections and identifying TLS. // // We use this value to determine if the incoming request is actual real - // native TLS (where we can demultiplex based on ALPN protocol) or our - // older type-byte system when new connections are established. + // native TLS (where we can de-multiplex based on ALPN protocol) or our older + // type-byte system when new connections are established. // // NOTE: if you add new RPCTypes beyond this value, you must similarly bump // this value. - RPCMaxTypeValue = 7 + RPCMaxTypeValue = 8 ) const ( @@ -66,6 +67,7 @@ const ( ALPN_RPCMultiplexV2 = "consul/rpc-multi" // RPCMultiplexV2 ALPN_RPCSnapshot = "consul/rpc-snapshot" // RPCSnapshot ALPN_RPCGossip = "consul/rpc-gossip" // RPCGossip + ALPN_RPCGRPC = "consul/rpc-grpc" // RPCGRPC // wan federation additions ALPN_WANGossipPacket = "consul/wan-gossip/packet" ALPN_WANGossipStream = "consul/wan-gossip/stream"