Remove non-gRPC request forwarding

This commit is contained in:
Jeff Mitchell 2017-05-24 09:34:59 -04:00
parent 44c2ef9601
commit 0d4e7fba69
21 changed files with 1699 additions and 486 deletions

View File

@ -1,6 +1,5 @@
// Code generated by protoc-gen-go. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: types.proto // source: types.proto
// DO NOT EDIT!
/* /*
Package forwarding is a generated protocol buffer package. Package forwarding is a generated protocol buffer package.

View File

@ -8,7 +8,6 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"os"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -143,23 +142,14 @@ func TestHTTP_Fallback_Disabled(t *testing.T) {
// This function recreates the fuzzy testing from transit to pipe a large // This function recreates the fuzzy testing from transit to pipe a large
// number of requests from the standbys to the active node. // number of requests from the standbys to the active node.
func TestHTTP_Forwarding_Stress(t *testing.T) { func TestHTTP_Forwarding_Stress(t *testing.T) {
testHTTP_Forwarding_Stress_Common(t, false, false, 50) testHTTP_Forwarding_Stress_Common(t, false, 50)
testHTTP_Forwarding_Stress_Common(t, false, true, 50) testHTTP_Forwarding_Stress_Common(t, true, 50)
testHTTP_Forwarding_Stress_Common(t, true, false, 50)
testHTTP_Forwarding_Stress_Common(t, true, true, 50)
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
} }
func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uint64) { func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) {
testPlaintext := "the quick brown fox" testPlaintext := "the quick brown fox"
testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA==" testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="
if rpc {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "1")
} else {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
}
handler1 := http.NewServeMux() handler1 := http.NewServeMux()
handler2 := http.NewServeMux() handler2 := http.NewServeMux()
handler3 := http.NewServeMux() handler3 := http.NewServeMux()
@ -465,9 +455,9 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uin
wg.Wait() wg.Wait()
if totalOps == 0 || totalOps != successfulOps { if totalOps == 0 || totalOps != successfulOps {
t.Fatalf("total/successful ops zero or mismatch: %d/%d; rpc: %t, parallel: %t, num %d", totalOps, successfulOps, rpc, parallel, num) t.Fatalf("total/successful ops zero or mismatch: %d/%d; parallel: %t, num %d", totalOps, successfulOps, parallel, num)
} }
t.Logf("total operations tried: %d, total successful: %d; rpc: %t, parallel: %t, num %d", totalOps, successfulOps, rpc, parallel, num) t.Logf("total operations tried: %d, total successful: %d; parallel: %t, num %d", totalOps, successfulOps, parallel, num)
} }
// This tests TLS connection state forwarding by ensuring that we can use a // This tests TLS connection state forwarding by ensuring that we can use a

View File

@ -19,8 +19,6 @@ import (
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
"golang.org/x/net/http2"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/forwarding" "github.com/hashicorp/vault/helper/forwarding"
@ -50,11 +48,6 @@ type clusterKeyParams struct {
D *big.Int `json:"d" structs:"d" mapstructure:"d"` D *big.Int `json:"d" structs:"d" mapstructure:"d"`
} }
type activeConnection struct {
transport *http2.Transport
clusterAddr string
}
// Structure representing the storage entry that holds cluster information // Structure representing the storage entry that holds cluster information
type Cluster struct { type Cluster struct {
// Name of the cluster // Name of the cluster

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os"
"testing" "testing"
"time" "time"
@ -176,18 +175,10 @@ func TestCluster_ForwardRequests(t *testing.T) {
// Make this nicer for tests // Make this nicer for tests
manualStepDownSleepPeriod = 5 * time.Second manualStepDownSleepPeriod = 5 * time.Second
testCluster_ForwardRequestsCommon(t, false) testCluster_ForwardRequestsCommon(t)
testCluster_ForwardRequestsCommon(t, true)
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
} }
func testCluster_ForwardRequestsCommon(t *testing.T, rpc bool) { func testCluster_ForwardRequestsCommon(t *testing.T) {
if rpc {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "1")
} else {
os.Setenv("VAULT_USE_GRPC_REQUEST_FORWARDING", "")
}
handler1 := http.NewServeMux() handler1 := http.NewServeMux()
handler1.HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) { handler1.HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")

View File

@ -304,8 +304,6 @@ type Core struct {
// Shutdown success channel. We need this to be done serially to ensure // Shutdown success channel. We need this to be done serially to ensure
// that binds are removed before they might be reinstated. // that binds are removed before they might be reinstated.
clusterListenerShutdownSuccessCh chan struct{} clusterListenerShutdownSuccessCh chan struct{}
// Connection info containing a client and a current active address
requestForwardingConnection *activeConnection
// Write lock used to ensure that we don't have multiple connections adjust // Write lock used to ensure that we don't have multiple connections adjust
// this value at the same time // this value at the same time
requestForwardingConnectionLock sync.RWMutex requestForwardingConnectionLock sync.RWMutex

View File

@ -1,14 +1,12 @@
package vault package vault
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -143,17 +141,6 @@ func (c *Core) startForwarding() error {
} }
switch tlsConn.ConnectionState().NegotiatedProtocol { switch tlsConn.ConnectionState().NegotiatedProtocol {
case "h2":
if !ha {
conn.Close()
continue
}
c.logger.Trace("core: got h2 connection")
go fws.ServeConn(conn, &http2.ServeConnOpts{
Handler: wrappedHandler,
})
case "req_fw_sb-act_v1": case "req_fw_sb-act_v1":
if !ha { if !ha {
conn.Close() conn.Close()
@ -231,37 +218,18 @@ func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
return err return err
} }
switch os.Getenv("VAULT_USE_GRPC_REQUEST_FORWARDING") { // Set up grpc forwarding handling
case "": // It's not really insecure, but we have to dial manually to get the
// Set up normal HTTP forwarding handling // ALPN header right. It's just "insecure" because GRPC isn't managing
tlsConfig, err := c.ClusterTLSConfig() // the TLS state.
if err != nil { ctx, cancelFunc := context.WithCancel(context.Background())
c.logger.Error("core: error fetching cluster tls configuration when trying to create connection", "error", err) c.rpcClientConnCancelFunc = cancelFunc
return err c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "", nil)), grpc.WithInsecure())
} if err != nil {
tp := &http2.Transport{ c.logger.Error("core: err setting up forwarding rpc client", "error", err)
TLSClientConfig: tlsConfig, return err
}
c.requestForwardingConnection = &activeConnection{
transport: tp,
clusterAddr: clusterAddr,
}
default:
// Set up grpc forwarding handling
// It's not really insecure, but we have to dial manually to get the
// ALPN header right. It's just "insecure" because GRPC isn't managing
// the TLS state.
ctx, cancelFunc := context.WithCancel(context.Background())
c.rpcClientConnCancelFunc = cancelFunc
c.rpcClientConn, err = grpc.DialContext(ctx, clusterURL.Host, grpc.WithDialer(c.getGRPCDialer("req_fw_sb-act_v1", "", nil)), grpc.WithInsecure())
if err != nil {
c.logger.Error("core: err setting up forwarding rpc client", "error", err)
return err
}
c.rpcForwardingClient = NewRequestForwardingClient(c.rpcClientConn)
} }
c.rpcForwardingClient = NewRequestForwardingClient(c.rpcClientConn)
return nil return nil
} }
@ -270,11 +238,6 @@ func (c *Core) clearForwardingClients() {
c.logger.Trace("core: clearing forwarding clients") c.logger.Trace("core: clearing forwarding clients")
defer c.logger.Trace("core: done clearing forwarding clients") defer c.logger.Trace("core: done clearing forwarding clients")
if c.requestForwardingConnection != nil {
c.requestForwardingConnection.transport.CloseIdleConnections()
c.requestForwardingConnection = nil
}
if c.rpcClientConnCancelFunc != nil { if c.rpcClientConnCancelFunc != nil {
c.rpcClientConnCancelFunc() c.rpcClientConnCancelFunc()
c.rpcClientConnCancelFunc = nil c.rpcClientConnCancelFunc = nil
@ -292,70 +255,36 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
c.requestForwardingConnectionLock.RLock() c.requestForwardingConnectionLock.RLock()
defer c.requestForwardingConnectionLock.RUnlock() defer c.requestForwardingConnectionLock.RUnlock()
switch os.Getenv("VAULT_USE_GRPC_REQUEST_FORWARDING") { if c.rpcForwardingClient == nil {
case "": return 0, nil, nil, ErrCannotForward
if c.requestForwardingConnection == nil { }
return 0, nil, nil, ErrCannotForward
}
if c.requestForwardingConnection.clusterAddr == "" { freq, err := forwarding.GenerateForwardedRequest(req)
return 0, nil, nil, ErrCannotForward if err != nil {
} c.logger.Error("core: error creating forwarding RPC request", "error", err)
return 0, nil, nil, fmt.Errorf("error creating forwarding RPC request")
}
if freq == nil {
c.logger.Error("core: got nil forwarding RPC request")
return 0, nil, nil, fmt.Errorf("got nil forwarding RPC request")
}
resp, err := c.rpcForwardingClient.ForwardRequest(context.Background(), freq, grpc.FailFast(true))
if err != nil {
c.logger.Error("core: error during forwarded RPC request", "error", err)
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
}
freq, err := forwarding.GenerateForwardedHTTPRequest(req, c.requestForwardingConnection.clusterAddr+"/cluster/local/forwarded-request") var header http.Header
if err != nil { if resp.HeaderEntries != nil {
c.logger.Error("core/ForwardRequest: error creating forwarded request", "error", err) header = make(http.Header)
return 0, nil, nil, fmt.Errorf("error creating forwarding request") for k, v := range resp.HeaderEntries {
} for _, j := range v.Values {
header.Add(k, j)
//resp, err := c.requestForwardingConnection.Do(freq)
resp, err := c.requestForwardingConnection.transport.RoundTrip(freq)
if err != nil {
return 0, nil, nil, err
}
defer resp.Body.Close()
// Read the body into a buffer so we can write it back out to the
// original requestor
buf := bytes.NewBuffer(nil)
_, err = buf.ReadFrom(resp.Body)
if err != nil {
return 0, nil, nil, err
}
return resp.StatusCode, resp.Header, buf.Bytes(), nil
default:
if c.rpcForwardingClient == nil {
return 0, nil, nil, ErrCannotForward
}
freq, err := forwarding.GenerateForwardedRequest(req)
if err != nil {
c.logger.Error("core/ForwardRequest: error creating forwarding RPC request", "error", err)
return 0, nil, nil, fmt.Errorf("error creating forwarding RPC request")
}
if freq == nil {
c.logger.Error("core/ForwardRequest: got nil forwarding RPC request")
return 0, nil, nil, fmt.Errorf("got nil forwarding RPC request")
}
resp, err := c.rpcForwardingClient.ForwardRequest(context.Background(), freq, grpc.FailFast(true))
if err != nil {
c.logger.Error("core/ForwardRequest: error during forwarded RPC request", "error", err)
return 0, nil, nil, fmt.Errorf("error during forwarding RPC request")
}
var header http.Header
if resp.HeaderEntries != nil {
header = make(http.Header)
for k, v := range resp.HeaderEntries {
for _, j := range v.Values {
header.Add(k, j)
}
} }
} }
return int(resp.StatusCode), header, resp.Body, nil
} }
return int(resp.StatusCode), header, resp.Body, nil
} }
// getGRPCDialer is used to return a dialer that has the correct TLS // getGRPCDialer is used to return a dialer that has the correct TLS

View File

@ -1,6 +1,5 @@
// Code generated by protoc-gen-go. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: request_forwarding_service.proto // source: request_forwarding_service.proto
// DO NOT EDIT!
/* /*
Package vault is a generated protocol buffer package. Package vault is a generated protocol buffer package.

View File

@ -26,9 +26,13 @@ Documentation
------------- -------------
See [API documentation](https://godoc.org/google.golang.org/grpc) for package and API descriptions and find examples in the [examples directory](examples/). See [API documentation](https://godoc.org/google.golang.org/grpc) for package and API descriptions and find examples in the [examples directory](examples/).
Performance
-----------
See the current benchmarks for some of the languages supported in [this dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696).
Status Status
------ ------
GA General Availability [Google Cloud Platform Launch Stages](https://cloud.google.com/terms/launch-stages).
FAQ FAQ
--- ---

View File

@ -35,6 +35,7 @@ package grpc
import ( import (
"fmt" "fmt"
"net"
"sync" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -60,6 +61,10 @@ type BalancerConfig struct {
// use to dial to a remote load balancer server. The Balancer implementations // use to dial to a remote load balancer server. The Balancer implementations
// can ignore this if it does not need to talk to another party securely. // can ignore this if it does not need to talk to another party securely.
DialCreds credentials.TransportCredentials DialCreds credentials.TransportCredentials
// Dialer is the custom dialer the Balancer implementation can use to dial
// to a remote load balancer server. The Balancer implementations
// can ignore this if it doesn't need to talk to remote balancer.
Dialer func(context.Context, string) (net.Conn, error)
} }
// BalancerGetOptions configures a Get call. // BalancerGetOptions configures a Get call.
@ -385,6 +390,9 @@ func (rr *roundRobin) Notify() <-chan []Address {
func (rr *roundRobin) Close() error { func (rr *roundRobin) Close() error {
rr.mu.Lock() rr.mu.Lock()
defer rr.mu.Unlock() defer rr.mu.Unlock()
if rr.done {
return errBalancerClosed
}
rr.done = true rr.done = true
if rr.w != nil { if rr.w != nil {
rr.w.Close() rr.w.Close()

102
vendor/google.golang.org/grpc/call.go generated vendored
View File

@ -73,7 +73,10 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
} }
} }
for { for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil { if c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -93,11 +96,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
} }
// sendRequest writes out various information of an RPC such as Context and Message. // sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
}
defer func() { defer func() {
if err != nil { if err != nil {
// If err is connection error, t will be closed, no need to close stream here. // If err is connection error, t will be closed, no need to close stream here.
@ -120,7 +119,13 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
} }
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload) outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil { if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err) return Errorf(codes.Internal, "grpc: %v", err)
}
if c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(outBuf) > *c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
} }
err = t.Write(stream, outBuf, opts) err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil { if err == nil && outPayload != nil {
@ -131,10 +136,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status. // recvResponse to get the final status.
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return nil, err return err
} }
// Sent successfully. // Sent successfully.
return stream, nil return nil
} }
// Invoke sends the RPC request on the wire and returns after response is received. // Invoke sends the RPC request on the wire and returns after response is received.
@ -149,14 +154,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo c := defaultCallInfo
if mc, ok := cc.getMethodConfig(method); ok { mc := cc.GetMethodConfig(method)
c.failFast = !mc.WaitForReady if mc.WaitForReady != nil {
if mc.Timeout > 0 { c.failFast = !*mc.WaitForReady
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
defer cancel()
}
} }
if mc.Timeout != nil && *mc.Timeout >= 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
defer cancel()
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {
return toRPCErr(err) return toRPCErr(err)
@ -167,6 +176,10 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
o.after(&c) o.after(&c)
} }
}() }()
c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
if EnableTracing { if EnableTracing {
c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
defer c.traceInfo.tr.Finish() defer c.traceInfo.tr.Finish()
@ -183,26 +196,25 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
}() }()
} }
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler sh := cc.dopts.copts.StatsHandler
if sh != nil { if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{ begin := &stats.Begin{
Client: true, Client: true,
BeginTime: time.Now(), BeginTime: time.Now(),
FailFast: c.failFast, FailFast: c.failFast,
} }
sh.HandleRPC(ctx, begin) sh.HandleRPC(ctx, begin)
} defer func() {
defer func() {
if sh != nil {
end := &stats.End{ end := &stats.End{
Client: true, Client: true,
EndTime: time.Now(), EndTime: time.Now(),
Error: e, Error: e,
} }
sh.HandleRPC(ctx, end) sh.HandleRPC(ctx, end)
} }()
}() }
topts := &transport.Options{ topts := &transport.Options{
Last: true, Last: true,
Delay: false, Delay: false,
@ -224,6 +236,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
if c.creds != nil {
callHdr.Creds = c.creds
}
gopts := BalancerGetOptions{ gopts := BalancerGetOptions{
BlockingWait: !c.failFast, BlockingWait: !c.failFast,
@ -246,19 +261,35 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts) stream, err = t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
if put != nil { if put != nil {
if _, ok := err.(transport.ConnectionError); ok {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
put()
}
if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
continue
}
return toRPCErr(err)
}
err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts)
if err != nil {
if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put() put()
put = nil
} }
// Retry a non-failfast RPC when // Retry a non-failfast RPC when
// i) there is a connection error; or // i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated. // ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
if c.failFast {
return toRPCErr(err)
}
continue continue
} }
return toRPCErr(err) return toRPCErr(err)
@ -266,13 +297,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
if put != nil { if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put() put()
put = nil
} }
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
if c.failFast {
return toRPCErr(err)
}
continue continue
} }
return toRPCErr(err) return toRPCErr(err)
@ -282,8 +313,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
t.CloseStream(stream, nil) t.CloseStream(stream, nil)
if put != nil { if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{
bytesSent: stream.BytesSent(),
bytesReceived: stream.BytesReceived(),
})
put() put()
put = nil
} }
return stream.Status().Err() return stream.Status().Err()
} }

View File

@ -36,8 +36,8 @@ package grpc
import ( import (
"errors" "errors"
"fmt" "fmt"
"math"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -56,8 +56,7 @@ var (
ErrClientConnClosing = errors.New("grpc: the client connection is closing") ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the ClientConn cannot establish the // ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout. // underlying connections within the specified timeout.
// DEPRECATED: Please use context.DeadlineExceeded instead. This error will be // DEPRECATED: Please use context.DeadlineExceeded instead.
// removed in Q1 2017.
ErrClientConnTimeout = errors.New("grpc: timed out when dialing") ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
// errNoTransportSecurity indicates that there is no transport security // errNoTransportSecurity indicates that there is no transport security
@ -79,6 +78,8 @@ var (
errConnClosing = errors.New("grpc: the connection is closing") errConnClosing = errors.New("grpc: the connection is closing")
// errConnUnavailable indicates that the connection is unavailable. // errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable") errConnUnavailable = errors.New("grpc: the connection is unavailable")
// errBalancerClosed indicates that the balancer is closed.
errBalancerClosed = errors.New("grpc: balancer is closed")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
@ -86,30 +87,54 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
unaryInt UnaryClientInterceptor unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor streamInt StreamClientInterceptor
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoffStrategy bs backoffStrategy
balancer Balancer balancer Balancer
block bool block bool
insecure bool insecure bool
timeout time.Duration timeout time.Duration
scChan <-chan ServiceConfig scChan <-chan ServiceConfig
copts transport.ConnectOptions copts transport.ConnectOptions
maxMsgSize int callOptions []CallOption
} }
const defaultClientMaxMsgSize = math.MaxInt32 const (
defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
defaultClientMaxSendMessageSize = 1024 * 1024 * 4
)
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
type DialOption func(*dialOptions) type DialOption func(*dialOptions)
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. // WithInitialWindowSize returns a DialOption which sets the value for initial window size on a stream.
func WithMaxMsgSize(s int) DialOption { // The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialWindowSize(s int32) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.maxMsgSize = s o.copts.InitialWindowSize = s
}
}
// WithInitialConnWindowSize returns a DialOption which sets the value for initial window size on a connection.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialConnWindowSize(s int32) DialOption {
return func(o *dialOptions) {
o.copts.InitialConnWindowSize = s
}
}
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.
func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
}
// WithDefaultCallOptions returns a DialOption which sets the default CallOptions for calls over the connection.
func WithDefaultCallOptions(cos ...CallOption) DialOption {
return func(o *dialOptions) {
o.callOptions = append(o.callOptions, cos...)
} }
} }
@ -204,7 +229,7 @@ func WithTransportCredentials(creds credentials.TransportCredentials) DialOption
} }
// WithPerRPCCredentials returns a DialOption which sets // WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC. // credentials and places auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption { func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds) o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
@ -241,7 +266,7 @@ func WithStatsHandler(h stats.Handler) DialOption {
} }
} }
// FailOnNonTempDialError returns a DialOption that specified if gRPC fails on non-temporary dial errors. // FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on non-temporary dial errors.
// If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network // If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network
// address and won't try to reconnect. // address and won't try to reconnect.
// The default value of FailOnNonTempDialError is false. // The default value of FailOnNonTempDialError is false.
@ -295,23 +320,29 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
} }
// DialContext creates a client connection to the given target. ctx can be used to // DialContext creates a client connection to the given target. ctx can be used to
// cancel or expire the pending connecting. Once this function returns, the // cancel or expire the pending connection. Once this function returns, the
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close // cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns. // to terminate all the pending operations after this function returns.
// This is the EXPERIMENTAL API.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(context.Background()) cc.ctx, cc.cancel = context.WithCancel(context.Background())
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
cc.mkp = cc.dopts.copts.KeepaliveParams cc.mkp = cc.dopts.copts.KeepaliveParams
grpcUA := "grpc-go/" + Version if cc.dopts.copts.Dialer == nil {
cc.dopts.copts.Dialer = newProxyDialer(
func(ctx context.Context, addr string) (net.Conn, error) {
return dialContext(ctx, "tcp", addr)
},
)
}
if cc.dopts.copts.UserAgent != "" { if cc.dopts.copts.UserAgent != "" {
cc.dopts.copts.UserAgent += " " + grpcUA cc.dopts.copts.UserAgent += " " + grpcUA
} else { } else {
@ -336,15 +367,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
} }
}() }()
scSet := false
if cc.dopts.scChan != nil { if cc.dopts.scChan != nil {
// Wait for the initial service config. // Try to get an initial service config.
select { select {
case sc, ok := <-cc.dopts.scChan: case sc, ok := <-cc.dopts.scChan:
if ok { if ok {
cc.sc = sc cc.sc = sc
scSet = true
} }
case <-ctx.Done(): default:
return nil, ctx.Err()
} }
} }
// Set defaults. // Set defaults.
@ -375,6 +407,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
} }
config := BalancerConfig{ config := BalancerConfig{
DialCreds: credsClone, DialCreds: credsClone,
Dialer: cc.dopts.copts.Dialer,
} }
if err := cc.dopts.balancer.Start(target, config); err != nil { if err := cc.dopts.balancer.Start(target, config); err != nil {
waitC <- err waitC <- err
@ -406,7 +439,17 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
return nil, err return nil, err
} }
} }
if cc.dopts.scChan != nil && !scSet {
// Blocking wait for the initial service config.
select {
case sc, ok := <-cc.dopts.scChan:
if ok {
cc.sc = sc
}
case <-ctx.Done():
return nil, ctx.Err()
}
}
if cc.dopts.scChan != nil { if cc.dopts.scChan != nil {
go cc.scWatcher() go cc.scWatcher()
} }
@ -616,12 +659,23 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error)
return nil return nil
} }
// TODO: Avoid the locking here. // GetMethodConfig gets the method config of the input method.
func (cc *ClientConn) getMethodConfig(method string) (m MethodConfig, ok bool) { // If there's an exact match for input method (i.e. /service/method), we return
// the corresponding MethodConfig.
// If there isn't an exact match for the input method, we look for the default config
// under the service (i.e /service/). If there is a default MethodConfig for
// the serivce, we return it.
// Otherwise, we return an empty MethodConfig.
func (cc *ClientConn) GetMethodConfig(method string) MethodConfig {
// TODO: Avoid the locking here.
cc.mu.RLock() cc.mu.RLock()
defer cc.mu.RUnlock() defer cc.mu.RUnlock()
m, ok = cc.sc.Methods[method] m, ok := cc.sc.Methods[method]
return if !ok {
i := strings.LastIndex(method, "/")
m, _ = cc.sc.Methods[method[:i+1]]
}
return m
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
@ -662,6 +716,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
} }
if !ok { if !ok {
if put != nil { if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put() put()
} }
return nil, nil, errConnClosing return nil, nil, errConnClosing
@ -669,6 +724,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait) t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
put() put()
} }
return nil, nil, err return nil, nil, err
@ -842,11 +898,14 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
} }
ac.mu.Unlock() ac.mu.Unlock()
closeTransport = false closeTransport = false
timer := time.NewTimer(sleepTime - time.Since(connectTime))
select { select {
case <-time.After(sleepTime - time.Since(connectTime)): case <-timer.C:
case <-ac.ctx.Done(): case <-ac.ctx.Done():
timer.Stop()
return ac.ctx.Err() return ac.ctx.Err()
} }
timer.Stop()
continue continue
} }
ac.mu.Lock() ac.mu.Lock()

View File

@ -96,6 +96,7 @@ func (p protoCodec) Marshal(v interface{}) ([]byte, error) {
func (p protoCodec) Unmarshal(data []byte, v interface{}) error { func (p protoCodec) Unmarshal(data []byte, v interface{}) error {
cb := protoBufferPool.Get().(*cachedProtoBuffer) cb := protoBufferPool.Get().(*cachedProtoBuffer)
cb.SetBuf(data) cb.SetBuf(data)
v.(proto.Message).Reset()
err := cb.Unmarshal(v.(proto.Message)) err := cb.Unmarshal(v.(proto.Message))
cb.SetBuf(nil) cb.SetBuf(nil)
protoBufferPool.Put(cb) protoBufferPool.Put(cb)

56
vendor/google.golang.org/grpc/go16.go generated vendored Normal file
View File

@ -0,0 +1,56 @@
// +build go1.6,!go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"fmt"
"net"
"net/http"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
req.Cancel = ctx.Done()
if err := req.Write(conn); err != nil {
return fmt.Errorf("failed to write the HTTP request: %v", err)
}
return nil
}

55
vendor/google.golang.org/grpc/go17.go generated vendored Normal file
View File

@ -0,0 +1,55 @@
// +build go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"net"
"net/http"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
req = req.WithContext(ctx)
if err := req.Write(conn); err != nil {
return err
}
return nil
}

765
vendor/google.golang.org/grpc/grpclb.go generated vendored Normal file
View File

@ -0,0 +1,765 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/naming"
)
// Client API for LoadBalancer service.
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
type loadBalancerClient struct {
cc *ClientConn
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption) (*balanceLoadClientStream, error) {
desc := &StreamDesc{
StreamName: "BalanceLoad",
ServerStreams: true,
ClientStreams: true,
}
stream, err := NewClientStream(ctx, desc, c.cc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &balanceLoadClientStream{stream}
return x, nil
}
type balanceLoadClientStream struct {
ClientStream
}
func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) {
m := new(lbpb.LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// AddressType indicates the address type returned by name resolution.
type AddressType uint8
const (
// Backend indicates the server is a backend server.
Backend AddressType = iota
// GRPCLB indicates the server is a grpclb load balancer.
GRPCLB
)
// AddrMetadataGRPCLB contains the information the name resolver for grpclb should provide. The
// name resolver used by the grpclb balancer is required to provide this type of metadata in
// its address updates.
type AddrMetadataGRPCLB struct {
// AddrType is the type of server (grpc load balancer or backend).
AddrType AddressType
// ServerName is the name of the grpc load balancer. Used for authentication.
ServerName string
}
// NewGRPCLBBalancer creates a grpclb load balancer.
func NewGRPCLBBalancer(r naming.Resolver) Balancer {
return &balancer{
r: r,
}
}
type remoteBalancerInfo struct {
addr string
// the server name used for authentication with the remote LB server.
name string
}
// grpclbAddrInfo consists of the information of a backend server.
type grpclbAddrInfo struct {
addr Address
connected bool
// dropForRateLimiting indicates whether this particular request should be
// dropped by the client for rate limiting.
dropForRateLimiting bool
// dropForLoadBalancing indicates whether this particular request should be
// dropped by the client for load balancing.
dropForLoadBalancing bool
}
type balancer struct {
r naming.Resolver
target string
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher
addrCh chan []Address
rbs []remoteBalancerInfo
addrs []*grpclbAddrInfo
next int
waitCh chan struct{}
done bool
expTimer *time.Timer
rand *rand.Rand
clientStats lbpb.ClientStats
}
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
updates, err := w.Next()
if err != nil {
grpclog.Printf("grpclb: failed to get next addr update from watcher: %v", err)
return err
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done {
return ErrClientConnClosing
}
for _, update := range updates {
switch update.Op {
case naming.Add:
var exist bool
for _, v := range b.rbs {
// TODO: Is the same addr with different server name a different balancer?
if update.Addr == v.addr {
exist = true
break
}
}
if exist {
continue
}
md, ok := update.Metadata.(*AddrMetadataGRPCLB)
if !ok {
// TODO: Revisit the handling here and may introduce some fallback mechanism.
grpclog.Printf("The name resolution contains unexpected metadata %v", update.Metadata)
continue
}
switch md.AddrType {
case Backend:
// TODO: Revisit the handling here and may introduce some fallback mechanism.
grpclog.Printf("The name resolution does not give grpclb addresses")
continue
case GRPCLB:
b.rbs = append(b.rbs, remoteBalancerInfo{
addr: update.Addr,
name: md.ServerName,
})
default:
grpclog.Printf("Received unknow address type %d", md.AddrType)
continue
}
case naming.Delete:
for i, v := range b.rbs {
if update.Addr == v.addr {
copy(b.rbs[i:], b.rbs[i+1:])
b.rbs = b.rbs[:len(b.rbs)-1]
break
}
}
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
// TODO: Fall back to the basic round-robin load balancing if the resulting address is
// not a load balancer.
select {
case <-ch:
default:
}
ch <- b.rbs
return nil
}
func (b *balancer) serverListExpire(seq int) {
b.mu.Lock()
defer b.mu.Unlock()
// TODO: gRPC interanls do not clear the connections when the server list is stale.
// This means RPCs will keep using the existing server list until b receives new
// server list even though the list is expired. Revisit this behavior later.
if b.done || seq < b.seq {
return
}
b.next = 0
b.addrs = nil
// Ask grpc internals to close all the corresponding connections.
b.addrCh <- nil
}
func convertDuration(d *lbpb.Duration) time.Duration {
if d == nil {
return 0
}
return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
}
func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
if l == nil {
return
}
servers := l.GetServers()
expiration := convertDuration(l.GetExpirationInterval())
var (
sl []*grpclbAddrInfo
addrs []Address
)
for _, s := range servers {
md := metadata.Pairs("lb-token", s.LoadBalanceToken)
addr := Address{
Addr: fmt.Sprintf("%s:%d", net.IP(s.IpAddress), s.Port),
Metadata: &md,
}
sl = append(sl, &grpclbAddrInfo{
addr: addr,
dropForRateLimiting: s.DropForRateLimiting,
dropForLoadBalancing: s.DropForLoadBalancing,
})
addrs = append(addrs, addr)
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done || seq < b.seq {
return
}
if len(sl) > 0 {
// reset b.next to 0 when replacing the server list.
b.next = 0
b.addrs = sl
b.addrCh <- addrs
if b.expTimer != nil {
b.expTimer.Stop()
b.expTimer = nil
}
if expiration > 0 {
b.expTimer = time.AfterFunc(expiration, func() {
b.serverListExpire(seq)
})
}
}
return
}
func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-done:
return
}
b.mu.Lock()
stats := b.clientStats
b.clientStats = lbpb.ClientStats{} // Clear the stats.
b.mu.Unlock()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
if err := s.Send(&lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
ClientStats: &stats,
},
}); err != nil {
grpclog.Printf("grpclb: failed to send load report: %v", err)
return
}
}
}
func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbc.BalanceLoad(ctx)
if err != nil {
grpclog.Printf("grpclb: failed to perform RPC to the remote balancer %v", err)
return
}
b.mu.Lock()
if b.done {
b.mu.Unlock()
return
}
b.mu.Unlock()
initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
InitialRequest: &lbpb.InitialLoadBalanceRequest{
Name: b.target,
},
},
}
if err := stream.Send(initReq); err != nil {
grpclog.Printf("grpclb: failed to send init request: %v", err)
// TODO: backoff on retry?
return true
}
reply, err := stream.Recv()
if err != nil {
grpclog.Printf("grpclb: failed to recv init response: %v", err)
// TODO: backoff on retry?
return true
}
initResp := reply.GetInitialResponse()
if initResp == nil {
grpclog.Println("grpclb: reply from remote balancer did not include initial response.")
return
}
// TODO: Support delegation.
if initResp.LoadBalancerDelegate != "" {
// delegation
grpclog.Println("TODO: Delegation is not supported yet.")
return
}
streamDone := make(chan struct{})
defer close(streamDone)
b.mu.Lock()
b.clientStats = lbpb.ClientStats{} // Clear client stats.
b.mu.Unlock()
if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
go b.sendLoadReport(stream, d, streamDone)
}
// Retrieve the server list.
for {
reply, err := stream.Recv()
if err != nil {
grpclog.Printf("grpclb: failed to recv server list: %v", err)
break
}
b.mu.Lock()
if b.done || seq < b.seq {
b.mu.Unlock()
return
}
b.seq++ // tick when receiving a new list of servers.
seq = b.seq
b.mu.Unlock()
if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList, seq)
}
}
return true
}
func (b *balancer) Start(target string, config BalancerConfig) error {
b.rand = rand.New(rand.NewSource(time.Now().Unix()))
// TODO: Fall back to the basic direct connection if there is no name resolver.
if b.r == nil {
return errors.New("there is no name resolver installed")
}
b.target = target
b.mu.Lock()
if b.done {
b.mu.Unlock()
return ErrClientConnClosing
}
b.addrCh = make(chan []Address)
w, err := b.r.Resolve(target)
if err != nil {
b.mu.Unlock()
grpclog.Printf("grpclb: failed to resolve address: %v, err: %v", target, err)
return err
}
b.w = w
b.mu.Unlock()
balancerAddrsCh := make(chan []remoteBalancerInfo, 1)
// Spawn a goroutine to monitor the name resolution of remote load balancer.
go func() {
for {
if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil {
grpclog.Printf("grpclb: the naming watcher stops working due to %v.\n", err)
close(balancerAddrsCh)
return
}
}
}()
// Spawn a goroutine to talk to the remote load balancer.
go func() {
var (
cc *ClientConn
// ccError is closed when there is an error in the current cc.
// A new rb should be picked from rbs and connected.
ccError chan struct{}
rb *remoteBalancerInfo
rbs []remoteBalancerInfo
rbIdx int
)
defer func() {
if ccError != nil {
select {
case <-ccError:
default:
close(ccError)
}
}
if cc != nil {
cc.Close()
}
}()
for {
var ok bool
select {
case rbs, ok = <-balancerAddrsCh:
if !ok {
return
}
foundIdx := -1
if rb != nil {
for i, trb := range rbs {
if trb == *rb {
foundIdx = i
break
}
}
}
if foundIdx >= 0 {
if foundIdx >= 1 {
// Move the address in use to the beginning of the list.
b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0]
rbIdx = 0
}
continue // If found, don't dial new cc.
} else if len(rbs) > 0 {
// Pick a random one from the list, instead of always using the first one.
if l := len(rbs); l > 1 && rb != nil {
tmpIdx := b.rand.Intn(l - 1)
b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
}
rbIdx = 0
rb = &rbs[0]
} else {
// foundIdx < 0 && len(rbs) <= 0.
rb = nil
}
case <-ccError:
ccError = nil
if rbIdx < len(rbs)-1 {
rbIdx++
rb = &rbs[rbIdx]
} else {
rb = nil
}
}
if rb == nil {
continue
}
if cc != nil {
cc.Close()
}
// Talk to the remote load balancer to get the server list.
var (
err error
dopts []DialOption
)
if creds := config.DialCreds; creds != nil {
if rb.name != "" {
if err := creds.OverrideServerName(rb.name); err != nil {
grpclog.Printf("grpclb: failed to override the server name in the credentials: %v", err)
continue
}
}
dopts = append(dopts, WithTransportCredentials(creds))
} else {
dopts = append(dopts, WithInsecure())
}
if dialer := config.Dialer; dialer != nil {
// WithDialer takes a different type of function, so we instead use a special DialOption here.
dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
}
ccError = make(chan struct{})
cc, err = Dial(rb.addr, dopts...)
if err != nil {
grpclog.Printf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
close(ccError)
continue
}
b.mu.Lock()
b.seq++ // tick when getting a new balancer address
seq := b.seq
b.next = 0
b.mu.Unlock()
go func(cc *ClientConn, ccError chan struct{}) {
lbc := &loadBalancerClient{cc}
b.callRemoteBalancer(lbc, seq)
cc.Close()
select {
case <-ccError:
default:
close(ccError)
}
}(cc, ccError)
}
}()
return nil
}
func (b *balancer) down(addr Address, err error) {
b.mu.Lock()
defer b.mu.Unlock()
for _, a := range b.addrs {
if addr == a.addr {
a.connected = false
break
}
}
}
func (b *balancer) Up(addr Address) func(error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.done {
return nil
}
var cnt int
for _, a := range b.addrs {
if a.addr == addr {
if a.connected {
return nil
}
a.connected = true
}
if a.connected && !a.dropForRateLimiting && !a.dropForLoadBalancing {
cnt++
}
}
// addr is the only one which is connected. Notify the Get() callers who are blocking.
if cnt == 1 && b.waitCh != nil {
close(b.waitCh)
b.waitCh = nil
}
return func(err error) {
b.down(addr, err)
}
}
func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
var ch chan struct{}
b.mu.Lock()
if b.done {
b.mu.Unlock()
err = ErrClientConnClosing
return
}
seq := b.seq
defer func() {
if err != nil {
return
}
put = func() {
s, ok := rpcInfoFromContext(ctx)
if !ok {
return
}
b.mu.Lock()
defer b.mu.Unlock()
if b.done || seq < b.seq {
return
}
b.clientStats.NumCallsFinished++
if !s.bytesSent {
b.clientStats.NumCallsFinishedWithClientFailedToSend++
} else if s.bytesReceived {
b.clientStats.NumCallsFinishedKnownReceived++
}
}
}()
b.clientStats.NumCallsStarted++
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0
}
next := b.next
for {
a := b.addrs[next]
next = (next + 1) % len(b.addrs)
if a.connected {
if !a.dropForRateLimiting && !a.dropForLoadBalancing {
addr = a.addr
b.next = next
b.mu.Unlock()
return
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr)
return
}
}
if next == b.next {
// Has iterated all the possible address but none is connected.
break
}
}
}
if !opts.BlockingWait {
if len(b.addrs) == 0 {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = Errorf(codes.Unavailable, "there is no address available")
return
}
// Returns the next addr on b.addrs for a failfast RPC.
addr = b.addrs[b.next].addr
b.next++
b.mu.Unlock()
return
}
// Wait on b.waitCh for non-failfast RPCs.
if b.waitCh == nil {
ch = make(chan struct{})
b.waitCh = ch
} else {
ch = b.waitCh
}
b.mu.Unlock()
for {
select {
case <-ctx.Done():
b.mu.Lock()
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ctx.Err()
return
case <-ch:
b.mu.Lock()
if b.done {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithClientFailedToSend++
b.mu.Unlock()
err = ErrClientConnClosing
return
}
if len(b.addrs) > 0 {
if b.next >= len(b.addrs) {
b.next = 0
}
next := b.next
for {
a := b.addrs[next]
next = (next + 1) % len(b.addrs)
if a.connected {
if !a.dropForRateLimiting && !a.dropForLoadBalancing {
addr = a.addr
b.next = next
b.mu.Unlock()
return
}
if !opts.BlockingWait {
b.next = next
if a.dropForLoadBalancing {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForLoadBalancing++
} else if a.dropForRateLimiting {
b.clientStats.NumCallsFinished++
b.clientStats.NumCallsFinishedWithDropForRateLimiting++
}
b.mu.Unlock()
err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr)
return
}
}
if next == b.next {
// Has iterated all the possible address but none is connected.
break
}
}
}
// The newly added addr got removed by Down() again.
if b.waitCh == nil {
ch = make(chan struct{})
b.waitCh = ch
} else {
ch = b.waitCh
}
b.mu.Unlock()
}
}
}
func (b *balancer) Notify() <-chan []Address {
return b.addrCh
}
func (b *balancer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.done {
return errBalancerClosed
}
b.done = true
if b.expTimer != nil {
b.expTimer.Stop()
}
if b.waitCh != nil {
close(b.waitCh)
}
if b.addrCh != nil {
close(b.addrCh)
}
if b.w != nil {
b.w.Close()
}
return nil
}

View File

@ -42,15 +42,15 @@ type UnaryInvoker func(ctx context.Context, method string, req, reply interface{
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. invoker is the handler to complete the RPC // UnaryClientInterceptor intercepts the execution of a unary RPC on the client. invoker is the handler to complete the RPC
// and it is the responsibility of the interceptor to call it. // and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API. // This is an EXPERIMENTAL API.
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
// Streamer is called by StreamClientInterceptor to create a ClientStream. // Streamer is called by StreamClientInterceptor to create a ClientStream.
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O // StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it. // operations. streamer is the handler to create a ClientStream and it is the responsibility of the interceptor to call it.
// This is the EXPERIMENTAL API. // This is an EXPERIMENTAL API.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
// UnaryServerInfo consists of various information about a unary RPC on // UnaryServerInfo consists of various information about a unary RPC on

145
vendor/google.golang.org/grpc/proxy.go generated vendored Normal file
View File

@ -0,0 +1,145 @@
/*
*
* Copyright 2017, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"golang.org/x/net/context"
)
var (
// errDisabled indicates that proxy is disabled for the address.
errDisabled = errors.New("proxy is disabled for the address")
// The following variable will be overwritten in the tests.
httpProxyFromEnvironment = http.ProxyFromEnvironment
)
func mapAddress(ctx context.Context, address string) (string, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
Host: address,
},
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return "", err
}
if url == nil {
return "", errDisabled
}
return url.Host, nil
}
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
// It's possible that this reader reads more than what's need for the response and stores
// those bytes in the buffer.
// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the
// bytes in the buffer.
type bufConn struct {
net.Conn
r io.Reader
}
func (c *bufConn) Read(b []byte) (int, error) {
return c.r.Read(b)
}
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
}
}()
req := (&http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Header: map[string][]string{"User-Agent": {grpcUA}},
})
if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
}
r := bufio.NewReader(conn)
resp, err := http.ReadResponse(r, req)
if err != nil {
return nil, fmt.Errorf("reading server HTTP response: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
dump, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status)
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}
return &bufConn{Conn: conn, r: r}, nil
}
// newProxyDialer returns a dialer that connects to proxy first if necessary.
// The returned dialer checks if a proxy is necessary, dial to the proxy with the
// provided dialer, does HTTP CONNECT handshake and returns the connection.
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
var skipHandshake bool
newAddr, err := mapAddress(ctx, addr)
if err != nil {
if err != errDisabled {
return nil, err
}
skipHandshake = true
newAddr = addr
}
conn, err = dialer(ctx, newAddr)
if err != nil {
return
}
if !skipHandshake {
conn, err = doHTTPConnectHandshake(ctx, conn, addr)
}
return
}
}

View File

@ -41,10 +41,12 @@ import (
"io/ioutil" "io/ioutil"
"math" "math"
"os" "os"
"sync"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -60,16 +62,24 @@ type Compressor interface {
Type() string Type() string
} }
// NewGZIPCompressor creates a Compressor based on GZIP. type gzipCompressor struct {
func NewGZIPCompressor() Compressor { pool sync.Pool
return &gzipCompressor{}
} }
type gzipCompressor struct { // NewGZIPCompressor creates a Compressor based on GZIP.
func NewGZIPCompressor() Compressor {
return &gzipCompressor{
pool: sync.Pool{
New: func() interface{} {
return gzip.NewWriter(ioutil.Discard)
},
},
}
} }
func (c *gzipCompressor) Do(w io.Writer, p []byte) error { func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
z := gzip.NewWriter(w) z := c.pool.Get().(*gzip.Writer)
z.Reset(w)
if _, err := z.Write(p); err != nil { if _, err := z.Write(p); err != nil {
return err return err
} }
@ -89,6 +99,7 @@ type Decompressor interface {
} }
type gzipDecompressor struct { type gzipDecompressor struct {
pool sync.Pool
} }
// NewGZIPDecompressor creates a Decompressor based on GZIP. // NewGZIPDecompressor creates a Decompressor based on GZIP.
@ -97,11 +108,26 @@ func NewGZIPDecompressor() Decompressor {
} }
func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) { func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
z, err := gzip.NewReader(r) var z *gzip.Reader
if err != nil { switch maybeZ := d.pool.Get().(type) {
return nil, err case nil:
newZ, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
z = newZ
case *gzip.Reader:
z = maybeZ
if err := z.Reset(r); err != nil {
d.pool.Put(z)
return nil, err
}
} }
defer z.Close()
defer func() {
z.Close()
d.pool.Put(z)
}()
return ioutil.ReadAll(z) return ioutil.ReadAll(z)
} }
@ -111,11 +137,14 @@ func (d *gzipDecompressor) Type() string {
// callInfo contains all related configuration and information about an RPC. // callInfo contains all related configuration and information about an RPC.
type callInfo struct { type callInfo struct {
failFast bool failFast bool
headerMD metadata.MD headerMD metadata.MD
trailerMD metadata.MD trailerMD metadata.MD
peer *peer.Peer peer *peer.Peer
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
} }
var defaultCallInfo = callInfo{failFast: true} var defaultCallInfo = callInfo{failFast: true}
@ -132,6 +161,14 @@ type CallOption interface {
after(*callInfo) after(*callInfo)
} }
// EmptyCallOption does not alter the Call configuration.
// It can be embedded in another structure to carry satellite data for use
// by interceptors.
type EmptyCallOption struct{}
func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo) {}
type beforeCall func(c *callInfo) error type beforeCall func(c *callInfo) error
func (o beforeCall) before(c *callInfo) error { return o(c) } func (o beforeCall) before(c *callInfo) error { return o(c) }
@ -173,7 +210,8 @@ func Peer(peer *peer.Peer) CallOption {
// immediately. Otherwise, the RPC client will block the call until a // immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will retry // connection is available (or the call is canceled or times out) and will retry
// the call if it fails due to a transient error. Please refer to // the call if it fails due to a transient error. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md. Note: failFast is default to true. // https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md.
// Note: failFast is default to true.
func FailFast(failFast bool) CallOption { func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error { return beforeCall(func(c *callInfo) error {
c.failFast = failFast c.failFast = failFast
@ -181,6 +219,31 @@ func FailFast(failFast bool) CallOption {
}) })
} }
// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive.
func MaxCallRecvMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error {
o.maxReceiveMessageSize = &s
return nil
})
}
// MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send.
func MaxCallSendMsgSize(s int) CallOption {
return beforeCall(func(o *callInfo) error {
o.maxSendMessageSize = &s
return nil
})
}
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call.
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
return beforeCall(func(c *callInfo) error {
c.creds = creds
return nil
})
}
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8
@ -214,8 +277,8 @@ type parser struct {
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying io.Reader must not return an incompatible
// error. // error.
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil { if _, err := p.r.Read(p.header[:]); err != nil {
return 0, nil, err return 0, nil, err
} }
@ -225,13 +288,13 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro
if length == 0 { if length == 0 {
return pf, nil, nil return pf, nil, nil
} }
if length > uint32(maxMsgSize) { if length > uint32(maxReceiveMessageSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
} }
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message: // of making it for each message:
msg = make([]byte, int(length)) msg = make([]byte, int(length))
if _, err := io.ReadFull(p.r, msg); err != nil { if _, err := p.r.Read(msg); err != nil {
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
@ -269,7 +332,7 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl
length = uint(len(b)) length = uint(len(b))
} }
if length > math.MaxUint32 { if length > math.MaxUint32 {
return nil, Errorf(codes.InvalidArgument, "grpc: message too large (%d bytes)", length) return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length)
} }
const ( const (
@ -310,8 +373,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error {
pf, d, err := p.recvMsg(maxMsgSize) pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return err return err
} }
@ -327,10 +390,10 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
if len(d) > maxMsgSize { if len(d) > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java // TODO: Revisit the error code. Currently keep it consistent with java
// implementation. // implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) return Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
} }
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
@ -345,6 +408,29 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
return nil return nil
} }
type rpcInfo struct {
bytesSent bool
bytesReceived bool
}
type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{})
}
func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
s, ok = ctx.Value(rpcInfoContextKey{}).(*rpcInfo)
return
}
func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
if ss, ok := rpcInfoFromContext(ctx); ok {
*ss = s
}
return
}
// Code returns the error code for err if it was produced by the rpc system. // Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown. // Otherwise, it returns codes.Unknown.
// //
@ -433,24 +519,22 @@ type MethodConfig struct {
// WaitForReady indicates whether RPCs sent to this method should wait until // WaitForReady indicates whether RPCs sent to this method should wait until
// the connection is ready by default (!failfast). The value specified via the // the connection is ready by default (!failfast). The value specified via the
// gRPC client API will override the value set here. // gRPC client API will override the value set here.
WaitForReady bool WaitForReady *bool
// Timeout is the default timeout for RPCs sent to this method. The actual // Timeout is the default timeout for RPCs sent to this method. The actual
// deadline used will be the minimum of the value specified here and the value // deadline used will be the minimum of the value specified here and the value
// set by the application via the gRPC client API. If either one is not set, // set by the application via the gRPC client API. If either one is not set,
// then the other will be used. If neither is set, then the RPC has no deadline. // then the other will be used. If neither is set, then the RPC has no deadline.
Timeout time.Duration Timeout *time.Duration
// MaxReqSize is the maximum allowed payload size for an individual request in a // MaxReqSize is the maximum allowed payload size for an individual request in a
// stream (client->server) in bytes. The size which is measured is the serialized // stream (client->server) in bytes. The size which is measured is the serialized
// payload after per-message compression (but before stream compression) in bytes. // payload after per-message compression (but before stream compression) in bytes.
// The actual value used is the minumum of the value specified here and the value set // The actual value used is the minumum of the value specified here and the value set
// by the application via the gRPC client API. If either one is not set, then the other // by the application via the gRPC client API. If either one is not set, then the other
// will be used. If neither is set, then the built-in default is used. // will be used. If neither is set, then the built-in default is used.
// TODO: support this. MaxReqSize *int
MaxReqSize uint32
// MaxRespSize is the maximum allowed payload size for an individual response in a // MaxRespSize is the maximum allowed payload size for an individual response in a
// stream (server->client) in bytes. // stream (server->client) in bytes.
// TODO: support this. MaxRespSize *int
MaxRespSize uint32
} }
// ServiceConfig is provided by the service provider and contains parameters for how // ServiceConfig is provided by the service provider and contains parameters for how
@ -461,9 +545,32 @@ type ServiceConfig struct {
// via grpc.WithBalancer will override this. // via grpc.WithBalancer will override this.
LB Balancer LB Balancer
// Methods contains a map for the methods in this service. // Methods contains a map for the methods in this service.
// If there is an exact match for a method (i.e. /service/method) in the map, use the corresponding MethodConfig.
// If there's no exact match, look for the default config for the service (/service/) and use the corresponding MethodConfig if it exists.
// Otherwise, the method has no MethodConfig to use.
Methods map[string]MethodConfig Methods map[string]MethodConfig
} }
func min(a, b *int) *int {
if *a < *b {
return a
}
return b
}
func getMaxSize(mcMax, doptMax *int, defaultVal int) *int {
if mcMax == nil && doptMax == nil {
return &defaultVal
}
if mcMax != nil && doptMax != nil {
return min(mcMax, doptMax)
}
if mcMax != nil {
return mcMax
}
return doptMax
}
// SupportPackageIsVersion4 is referenced from generated protocol buffer files // SupportPackageIsVersion4 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the grpc package. // to assert that that code is compatible with this version of the grpc package.
// //
@ -473,4 +580,6 @@ type ServiceConfig struct {
const SupportPackageIsVersion4 = true const SupportPackageIsVersion4 = true
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.3.0-dev" const Version = "1.4.0-dev"
const grpcUA = "grpc-go/" + Version

View File

@ -61,6 +61,11 @@ import (
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
const (
defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
defaultServerMaxSendMessageSize = 1024 * 1024 * 4
)
type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error) type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
// MethodDesc represents an RPC service's method specification. // MethodDesc represents an RPC service's method specification.
@ -107,27 +112,49 @@ type Server struct {
} }
type options struct { type options struct {
creds credentials.TransportCredentials creds credentials.TransportCredentials
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
maxMsgSize int unaryInt UnaryServerInterceptor
unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor
streamInt StreamServerInterceptor inTapHandle tap.ServerInHandle
inTapHandle tap.ServerInHandle statsHandler stats.Handler
statsHandler stats.Handler maxConcurrentStreams uint32
maxConcurrentStreams uint32 maxReceiveMessageSize int
useHandlerImpl bool // use http.Handler-based server maxSendMessageSize int
unknownStreamDesc *StreamDesc useHandlerImpl bool // use http.Handler-based server
keepaliveParams keepalive.ServerParameters unknownStreamDesc *StreamDesc
keepalivePolicy keepalive.EnforcementPolicy keepaliveParams keepalive.ServerParameters
keepalivePolicy keepalive.EnforcementPolicy
initialWindowSize int32
initialConnWindowSize int32
} }
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit var defaultServerOptions = options{
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
maxSendMessageSize: defaultServerMaxSendMessageSize,
}
// A ServerOption sets options. // A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
type ServerOption func(*options) type ServerOption func(*options)
// InitialWindowSize returns a ServerOption that sets window size for stream.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func InitialWindowSize(s int32) ServerOption {
return func(o *options) {
o.initialWindowSize = s
}
}
// InitialConnWindowSize returns a ServerOption that sets window size for a connection.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func InitialConnWindowSize(s int32) ServerOption {
return func(o *options) {
o.initialConnWindowSize = s
}
}
// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
return func(o *options) { return func(o *options) {
@ -163,11 +190,25 @@ func RPCDecompressor(dc Decompressor) ServerOption {
} }
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages. // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default 4MB. // If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead.
func MaxMsgSize(m int) ServerOption { func MaxMsgSize(m int) ServerOption {
return MaxRecvMsgSize(m)
}
// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default 4MB.
func MaxRecvMsgSize(m int) ServerOption {
return func(o *options) { return func(o *options) {
o.maxMsgSize = m o.maxReceiveMessageSize = m
}
}
// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
// If this is not set, gRPC uses the default 4MB.
func MaxSendMsgSize(m int) ServerOption {
return func(o *options) {
o.maxSendMessageSize = m
} }
} }
@ -192,7 +233,7 @@ func Creds(c credentials.TransportCredentials) ServerOption {
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
return func(o *options) { return func(o *options) {
if o.unaryInt != nil { if o.unaryInt != nil {
panic("The unary server interceptor has been set.") panic("The unary server interceptor was already set and may not be reset.")
} }
o.unaryInt = i o.unaryInt = i
} }
@ -203,7 +244,7 @@ func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
func StreamInterceptor(i StreamServerInterceptor) ServerOption { func StreamInterceptor(i StreamServerInterceptor) ServerOption {
return func(o *options) { return func(o *options) {
if o.streamInt != nil { if o.streamInt != nil {
panic("The stream server interceptor has been set.") panic("The stream server interceptor was already set and may not be reset.")
} }
o.streamInt = i o.streamInt = i
} }
@ -214,7 +255,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
func InTapHandle(h tap.ServerInHandle) ServerOption { func InTapHandle(h tap.ServerInHandle) ServerOption {
return func(o *options) { return func(o *options) {
if o.inTapHandle != nil { if o.inTapHandle != nil {
panic("The tap handle has been set.") panic("The tap handle was already set and may not be reset.")
} }
o.inTapHandle = h o.inTapHandle = h
} }
@ -229,7 +270,7 @@ func StatsHandler(h stats.Handler) ServerOption {
// UnknownServiceHandler returns a ServerOption that allows for adding a custom // UnknownServiceHandler returns a ServerOption that allows for adding a custom
// unknown service handler. The provided method is a bidi-streaming RPC service // unknown service handler. The provided method is a bidi-streaming RPC service
// handler that will be invoked instead of returning the the "unimplemented" gRPC // handler that will be invoked instead of returning the "unimplemented" gRPC
// error whenever a request is received for an unregistered service or method. // error whenever a request is received for an unregistered service or method.
// The handling function has full access to the Context of the request and the // The handling function has full access to the Context of the request and the
// stream, and the invocation passes through interceptors. // stream, and the invocation passes through interceptors.
@ -248,8 +289,7 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options opts := defaultServerOptions
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -288,8 +328,8 @@ func (s *Server) errorf(format string, a ...interface{}) {
} }
} }
// RegisterService register a service and its implementation to the gRPC // RegisterService registers a service and its implementation to the gRPC
// server. Called from the IDL generated code. This must be called before // server. It is called from the IDL generated code. This must be called before
// invoking Serve. // invoking Serve.
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) { func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
ht := reflect.TypeOf(sd.HandlerType).Elem() ht := reflect.TypeOf(sd.HandlerType).Elem()
@ -334,7 +374,7 @@ type MethodInfo struct {
IsServerStream bool IsServerStream bool
} }
// ServiceInfo contains unary RPC method info, streaming RPC methid info and metadata for a service. // ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
type ServiceInfo struct { type ServiceInfo struct {
Methods []MethodInfo Methods []MethodInfo
// Metadata is the metadata specified in ServiceDesc when registering service. // Metadata is the metadata specified in ServiceDesc when registering service.
@ -427,10 +467,12 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock() s.mu.Lock()
s.printf("Accept error: %v; retrying in %v", err, tempDelay) s.printf("Accept error: %v; retrying in %v", err, tempDelay)
s.mu.Unlock() s.mu.Unlock()
timer := time.NewTimer(tempDelay)
select { select {
case <-time.After(tempDelay): case <-timer.C:
case <-s.ctx.Done(): case <-s.ctx.Done():
} }
timer.Stop()
continue continue
} }
s.mu.Lock() s.mu.Lock()
@ -483,12 +525,14 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
// transport.NewServerTransport). // transport.NewServerTransport).
func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
config := &transport.ServerConfig{ config := &transport.ServerConfig{
MaxStreams: s.opts.maxConcurrentStreams, MaxStreams: s.opts.maxConcurrentStreams,
AuthInfo: authInfo, AuthInfo: authInfo,
InTapHandle: s.opts.inTapHandle, InTapHandle: s.opts.inTapHandle,
StatsHandler: s.opts.statsHandler, StatsHandler: s.opts.statsHandler,
KeepaliveParams: s.opts.keepaliveParams, KeepaliveParams: s.opts.keepaliveParams,
KeepalivePolicy: s.opts.keepalivePolicy, KeepalivePolicy: s.opts.keepalivePolicy,
InitialWindowSize: s.opts.initialWindowSize,
InitialConnWindowSize: s.opts.initialConnWindowSize,
} }
st, err := transport.NewServerTransport("http2", c, config) st, err := transport.NewServerTransport("http2", c, config)
if err != nil { if err != nil {
@ -629,6 +673,9 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
// the optimal option. // the optimal option.
grpclog.Fatalf("grpc: Server failed to encode response %v", err) grpclog.Fatalf("grpc: Server failed to encode response %v", err)
} }
if len(p) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize)
}
err = t.Write(stream, p, opts) err = t.Write(stream, p, opts)
if err == nil && outPayload != nil { if err == nil && outPayload != nil {
outPayload.SentTime = time.Now() outPayload.SentTime = time.Now()
@ -644,9 +691,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
BeginTime: time.Now(), BeginTime: time.Now(),
} }
sh.HandleRPC(stream.Context(), begin) sh.HandleRPC(stream.Context(), begin)
} defer func() {
defer func() {
if sh != nil {
end := &stats.End{ end := &stats.End{
EndTime: time.Now(), EndTime: time.Now(),
} }
@ -654,8 +699,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
end.Error = toRPCErr(err) end.Error = toRPCErr(err)
} }
sh.HandleRPC(stream.Context(), end) sh.HandleRPC(stream.Context(), end)
} }()
}() }
if trInfo != nil { if trInfo != nil {
defer trInfo.tr.Finish() defer trInfo.tr.Finish()
trInfo.firstLine.client = false trInfo.firstLine.client = false
@ -672,139 +717,137 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
stream.SetSendCompress(s.opts.cp.Type()) stream.SetSendCompress(s.opts.cp.Type())
} }
p := &parser{r: stream} p := &parser{r: stream}
for { // TODO: delete pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
}
if err == io.ErrUnexpectedEOF {
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
}
if err != nil {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
} else {
switch st := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
}
}
return err
}
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return err
}
if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
// TODO checkRecvPayload always return RPC error. Add a return here if necessary.
}
var inPayload *stats.InPayload
if sh != nil {
inPayload = &stats.InPayload{
RecvTime: time.Now(),
}
}
df := func(v interface{}) error {
if inPayload != nil {
inPayload.WireLength = len(req)
}
if pf == compressionMade {
var err error
req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil {
return Errorf(codes.Internal, err.Error())
}
}
if len(req) > s.opts.maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
if inPayload != nil {
inPayload.Payload = v
inPayload.Data = req
inPayload.Length = len(req)
sh.HandleRPC(stream.Context(), inPayload)
}
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
return nil
}
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
// Convert appErr if it is not a grpc status error.
appErr = status.Error(convertCode(appErr), appErr.Error())
appStatus, _ = status.FromError(appErr)
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
trInfo.tr.SetError()
}
if e := t.WriteStatus(stream, appStatus); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
}
return appErr
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer("OK"), false)
}
opts := &transport.Options{
Last: true,
Delay: false,
}
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
} }
if err == io.ErrUnexpectedEOF { if s, ok := status.FromError(err); ok {
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) if e := t.WriteStatus(stream, s); e != nil {
}
if err != nil {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
} else {
switch st := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
}
}
return err
}
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return err
}
if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
// TODO checkRecvPayload always return RPC error. Add a return here if necessary.
}
var inPayload *stats.InPayload
if sh != nil {
inPayload = &stats.InPayload{
RecvTime: time.Now(),
}
}
df := func(v interface{}) error {
if inPayload != nil {
inPayload.WireLength = len(req)
}
if pf == compressionMade {
var err error
req, err = s.opts.dc.Do(bytes.NewReader(req))
if err != nil {
return Errorf(codes.Internal, err.Error())
}
}
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
if inPayload != nil {
inPayload.Payload = v
inPayload.Data = req
inPayload.Length = len(req)
sh.HandleRPC(stream.Context(), inPayload)
}
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
}
return nil
}
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
// Convert appErr if it is not a grpc status error.
appErr = status.Error(convertCode(appErr), appErr.Error())
appStatus, _ = status.FromError(appErr)
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
trInfo.tr.SetError()
}
if e := t.WriteStatus(stream, appStatus); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
} }
return appErr } else {
} switch st := err.(type) {
if trInfo != nil { case transport.ConnectionError:
trInfo.tr.LazyLog(stringer("OK"), false) // Nothing to do here.
} case transport.StreamError:
opts := &transport.Options{ if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
Last: true, grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
Delay: false,
}
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
}
if s, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, s); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
}
} else {
switch st := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
} }
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
} }
return err
} }
if trInfo != nil { return err
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
// TODO: Should we be logging if writing status failed here, like above?
// Should the logging be in WriteStatus? Should we ignore the WriteStatus
// error or allow the stats handler to see it?
return t.WriteStatus(stream, status.New(codes.OK, ""))
} }
if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
}
// TODO: Should we be logging if writing status failed here, like above?
// Should the logging be in WriteStatus? Should we ignore the WriteStatus
// error or allow the stats handler to see it?
return t.WriteStatus(stream, status.New(codes.OK, ""))
} }
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
@ -814,9 +857,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
BeginTime: time.Now(), BeginTime: time.Now(),
} }
sh.HandleRPC(stream.Context(), begin) sh.HandleRPC(stream.Context(), begin)
} defer func() {
defer func() {
if sh != nil {
end := &stats.End{ end := &stats.End{
EndTime: time.Now(), EndTime: time.Now(),
} }
@ -824,21 +865,22 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
end.Error = toRPCErr(err) end.Error = toRPCErr(err)
} }
sh.HandleRPC(stream.Context(), end) sh.HandleRPC(stream.Context(), end)
} }()
}() }
if s.opts.cp != nil { if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type()) stream.SetSendCompress(s.opts.cp.Type())
} }
ss := &serverStream{ ss := &serverStream{
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream}, p: &parser{r: stream},
codec: s.opts.codec, codec: s.opts.codec,
cp: s.opts.cp, cp: s.opts.cp,
dc: s.opts.dc, dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
trInfo: trInfo, maxSendMessageSize: s.opts.maxSendMessageSize,
statsHandler: sh, trInfo: trInfo,
statsHandler: sh,
} }
if ss.cp != nil { if ss.cp != nil {
ss.cbuf = new(bytes.Buffer) ss.cbuf = new(bytes.Buffer)
@ -913,7 +955,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.SetError() trInfo.tr.SetError()
} }
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError() trInfo.tr.SetError()
@ -1011,8 +1053,9 @@ func (s *Server) Stop() {
s.mu.Unlock() s.mu.Unlock()
} }
// GracefulStop stops the gRPC server gracefully. It stops the server to accept new // GracefulStop stops the gRPC server gracefully. It stops the server from
// connections and RPCs and blocks until all the pending RPCs are finished. // accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func (s *Server) GracefulStop() { func (s *Server) GracefulStop() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()

View File

@ -113,17 +113,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
cancel context.CancelFunc cancel context.CancelFunc
) )
c := defaultCallInfo c := defaultCallInfo
if mc, ok := cc.getMethodConfig(method); ok { mc := cc.GetMethodConfig(method)
c.failFast = !mc.WaitForReady if mc.WaitForReady != nil {
if mc.Timeout > 0 { c.failFast = !*mc.WaitForReady
ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
}
} }
if mc.Timeout != nil {
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
}
opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
} }
c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
@ -132,6 +139,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
if c.creds != nil {
callHdr.Creds = c.creds
}
var trInfo traceInfo var trInfo traceInfo
if EnableTracing { if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
@ -151,26 +161,27 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
}() }()
} }
ctx = newContextWithRPCInfo(ctx)
sh := cc.dopts.copts.StatsHandler sh := cc.dopts.copts.StatsHandler
if sh != nil { if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
begin := &stats.Begin{ begin := &stats.Begin{
Client: true, Client: true,
BeginTime: time.Now(), BeginTime: time.Now(),
FailFast: c.failFast, FailFast: c.failFast,
} }
sh.HandleRPC(ctx, begin) sh.HandleRPC(ctx, begin)
} defer func() {
defer func() { if err != nil {
if err != nil && sh != nil { // Only handle end stats if err != nil.
// Only handle end stats if err != nil. end := &stats.End{
end := &stats.End{ Client: true,
Client: true, Error: err,
Error: err, }
sh.HandleRPC(ctx, end)
} }
sh.HandleRPC(ctx, end) }()
} }
}()
gopts := BalancerGetOptions{ gopts := BalancerGetOptions{
BlockingWait: !c.failFast, BlockingWait: !c.failFast,
} }
@ -193,14 +204,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s, err = t.NewStream(ctx, callHdr) s, err = t.NewStream(ctx, callHdr)
if err != nil { if err != nil {
if _, ok := err.(transport.ConnectionError); ok && put != nil {
// If error is connection error, transport was sending data on wire,
// and we are not sure if anything has been sent on wire.
// If error is not connection error, we are sure nothing has been sent.
updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
}
if put != nil { if put != nil {
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
if c.failFast {
return nil, toRPCErr(err)
}
continue continue
} }
return nil, toRPCErr(err) return nil, toRPCErr(err)
@ -208,14 +222,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break break
} }
cs := &clientStream{ cs := &clientStream{
opts: opts, opts: opts,
c: c, c: c,
desc: desc, desc: desc,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cc.dopts.cp, cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
maxMsgSize: cc.dopts.maxMsgSize, cancel: cancel,
cancel: cancel,
put: put, put: put,
t: t, t: t,
@ -237,6 +250,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
select { select {
case <-t.Error(): case <-t.Error():
// Incur transport error, simply exit. // Incur transport error, simply exit.
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
cs.closeTransportStream(ErrClientConnClosing)
case <-s.Done(): case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending // TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution. // I/O, which is probably not the optimal solution.
@ -256,18 +272,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption opts []CallOption
c callInfo c callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
desc *StreamDesc desc *StreamDesc
codec Codec codec Codec
cp Compressor cp Compressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
dc Decompressor dc Decompressor
maxMsgSize int cancel context.CancelFunc
cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
@ -351,6 +366,12 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil { if err != nil {
return Errorf(codes.Internal, "grpc: %v", err) return Errorf(codes.Internal, "grpc: %v", err)
} }
if cs.c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(out) > *cs.c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize)
}
err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
if err == nil && outPayload != nil { if err == nil && outPayload != nil {
outPayload.SentTime = time.Now() outPayload.SentTime = time.Now()
@ -366,7 +387,10 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload) if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -389,7 +413,10 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil) if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -460,6 +487,10 @@ func (cs *clientStream) finish(err error) {
o.after(&cs.c) o.after(&cs.c)
} }
if cs.put != nil { if cs.put != nil {
updateRPCInfoInContext(cs.s.Context(), rpcInfo{
bytesSent: cs.s.BytesSent(),
bytesReceived: cs.s.BytesReceived(),
})
cs.put() cs.put()
cs.put = nil cs.put = nil
} }
@ -510,15 +541,16 @@ type ServerStream interface {
// serverStream implements a server side Stream. // serverStream implements a server side Stream.
type serverStream struct { type serverStream struct {
t transport.ServerTransport t transport.ServerTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int maxReceiveMessageSize int
trInfo *traceInfo maxSendMessageSize int
trInfo *traceInfo
statsHandler stats.Handler statsHandler stats.Handler
@ -577,6 +609,9 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
err = Errorf(codes.Internal, "grpc: %v", err) err = Errorf(codes.Internal, "grpc: %v", err)
return err return err
} }
if len(out) > ss.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
}
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
@ -606,7 +641,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
inPayload = &stats.InPayload{} inPayload = &stats.InPayload{}
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil { if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
if err == io.EOF { if err == io.EOF {
return err return err
} }

6
vendor/vendor.json vendored
View File

@ -1579,10 +1579,10 @@
"revisionTime": "2017-04-04T13:20:09Z" "revisionTime": "2017-04-04T13:20:09Z"
}, },
{ {
"checksumSHA1": "0hmkKgzN2Efaa/45M+eK+RMAh20=", "checksumSHA1": "/rSFejOYzXTDxhpcW/vN7EFU2Qs=",
"path": "google.golang.org/grpc", "path": "google.golang.org/grpc",
"revision": "8a6eb0f6e96a246c7c655b0207b62255042f6fbd", "revision": "6dff7c5f333b5331089ccc6451fc842664f3be6a",
"revisionTime": "2017-04-13T20:41:35Z" "revisionTime": "2017-05-23T18:39:15Z"
}, },
{ {
"checksumSHA1": "08icuA15HRkdYCt6H+Cs90RPQsY=", "checksumSHA1": "08icuA15HRkdYCt6H+Cs90RPQsY=",